diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt
index 1d50c2b879f4acc6e2ceaed90ede7c2a502c9e8f..06b3fbbd4e6c50e4734b489c1a7170b745dab809 100644
--- a/.jenkins/check/config/whitelizard.txt
+++ b/.jenkins/check/config/whitelizard.txt
@@ -13,3 +13,28 @@ mindscience/MindChemistry/mindchemistry/so2_conv/wigner.py:wigner_D
 mindscience/MindSPONGE/src/sponge/system/modelling/mol2_parser.py:mol2parser
 mindscience/MindSPONGE/src/sponge/system/modelling/hadder.py:add_hydrogen
 mindscience/MindSPONGE/src/sponge/system/molecule/molecule.py:build_system
+
+## MindSPONGE Grasp/multimer-parallel
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/residue.py:__init__
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/residue.py:add_atom
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_data.py:__init__
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/basic_modules.py:construct
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_data.py:random_crop_to_size
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/pairs.py:__init__
+mindscience/MindSPONGE/applications/research/Grasp/data/parsers.py:parse_mmcif
+mindscience/MindSPONGE/applications/research/Grasp/data/preprocess.py:non_ensemble
+mindscience/MindSPONGE/applications/research/Grasp/data/preprocess.py:ensemble
+mindscience/MindSPONGE/applications/research/Grasp/utils_infer.py:grasp_infer
+mindscience/MindSPONGE/applications/research/Grasp/restraint_sample.py:generate_interface_and_restraints
+mindscience/MindSPONGE/applications/research/Grasp/model/fold.py:__init__
+mindscience/MindSPONGE/applications/research/Grasp/model/assessment.py:construct
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/data/data_transform.py:atom37_to_torsion_angles
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/data/data_transform.py:atom37_to_frames
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/structure_violations.py:frame_aligned_point_error_map
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/structure_violations.py:frame_aligned_point_error
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/nn_arch.py:construct
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/potential/forcefield.py:__init__
+mindscience/MindSPONGE/applications/research/Grasp/mindsponge1/partition/grids.py:__init__
+mindscience/MindSPONGE/applications/research/Grasp/module/template_embedding_new.py:construct
+mindscience/MindSPONGE/applications/research/Grasp/utils_infer.py:filter_restraints
+mindscience/MindSPONGE/applications/research/Grasp/utils_infer.py:grasp_infer_quick
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Grasp/README.md b/MindSPONGE/applications/research/Grasp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..37a3f0c2a1b1db66a11338adcc8993881290c01c
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/README.md
@@ -0,0 +1,764 @@
# Multimer multi-card parallel inference

## 1 Environment dependencies

### 1.1 Firmware/driver and CANN package versions

```bash
# cat /usr/local/Ascend/ascend-toolkit/latest/version.cfg
runtime_running_version=[7.3.0.1.231:8.0.RC2]
compiler_running_version=[7.3.0.1.231:8.0.RC2]
hccl_running_version=[7.3.0.1.231:8.0.RC2]
opp_running_version=[7.3.0.1.231:8.0.RC2]
toolkit_running_version=[7.3.0.1.231:8.0.RC2]
aoe_running_version=[7.3.0.1.231:8.0.RC2]
ncs_running_version=[7.3.0.1.231:8.0.RC2]
opp_kernel_running_version=[7.3.0.1.231:8.0.RC2]
toolkit_upgrade_version=[7.3.0.1.231:8.0.RC2]
aoe_upgrade_version=[7.3.0.1.231:8.0.RC2]
ncs_upgrade_version=[7.3.0.1.231:8.0.RC2]
opp_kernel_upgrade_version=[7.3.0.1.231:8.0.RC2]
opp_upgrade_version=[7.3.0.1.231:8.0.RC2]
runtime_upgrade_version=[7.3.0.1.231:8.0.RC2]
compiler_upgrade_version=[7.3.0.1.231:8.0.RC2]
hccl_upgrade_version=[7.3.0.1.231:8.0.RC2]
runtime_installed_version=[7.0.0.5.242:7.0.RC1][7.1.0.3.220:7.0.0][7.3.0.1.231:8.0.RC2]
compiler_installed_version=[7.0.0.5.242:7.0.RC1][7.1.0.3.220:7.0.0][7.3.0.1.231:8.0.RC2]
opp_installed_version=[7.0.0.5.242:7.0.RC1][7.1.0.3.220:7.0.0][7.3.0.1.231:8.0.RC2]
toolkit_installed_version=[7.0.0.5.242:7.0.RC1][7.1.0.3.220:7.0.0][7.3.0.1.231:8.0.RC2]
aoe_installed_version=[7.0.0.5.242:7.0.RC1][7.1.0.3.220:7.0.0][7.3.0.1.231:8.0.RC2]
ncs_installed_version=[7.0.0.5.242:7.0.RC1][7.1.0.3.220:7.0.0][7.3.0.1.231:8.0.RC2]
opp_kernel_installed_version=[7.2.T7.0.B121:8.0.RC1.alpha002][7.3.0.1.231:8.0.RC2]
hccl_installed_version=[7.3.0.1.231:8.0.RC2]

```

### 1.2 conda environment dependencies

```bash
# source activate python310 && pip list
absl-py==2.1.0
aiohappyeyeballs==2.4.4
aiohttp==3.11.11
aiosignal==1.3.2
anyio==4.8.0
ascendebug @ file:///usr/local/Ascend/ascend-toolkit/8.0.RC2/toolkit/tools/ascendebug-0.1.0-py3-none-any.whl
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
astunparse @ file:///home/conda/feedstock_root/build_artifacts/astunparse_1736248061654/work
async-timeout==5.0.1
attrs==25.1.0
auto-tune @ file:///root/selfgz130520532488/compiler/lib64/auto_tune-0.1.0-py3-none-any.whl
bio==1.7.1
biopython==1.81
biothings_client==0.4.1
biotite==0.40.0
Bottleneck @ file:///croot/bottleneck_1731058648584/work
certifi==2024.12.14
charset-normalizer==3.4.1
click==8.1.8
cloudpickle==3.1.1
contourpy==1.3.1
cycler==0.12.1
dataclasses==0.6
dataflow @ file:///root/selfgz130520532488/compiler/lib64/dataflow-0.0.1-py3-none-any.whl
datasets==2.18.0
decorator==5.1.1
descriptastorus==2.6.1
dill==0.3.8
exceptiongroup==1.2.2
filelock==3.17.0
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2024.2.0
ftfy==6.3.1
glob2==0.7
gprofiler-official==1.0.0
h11==0.14.0
h5py==3.12.1
hccl @ file:///root/selfgz132073717241/hccl/lib64/hccl-0.1.0-py3-none-any.whl
hccl-parser @ file:///usr/local/Ascend/ascend-toolkit/8.0.RC2/toolkit/tools/hccl_parser-0.1-py3-none-any.whl
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.27.1
idna==3.10
jieba==0.42.1
Jinja2==3.1.5
joblib==1.4.2
kiwisolver==1.4.8
llm-datadist @ file:///root/selfgz130520532488/compiler/lib64/llm_datadist-0.0.1-py3-none-any.whl
llm-engine @ file:///root/selfgz130520532488/compiler/lib64/llm_engine-0.0.1-py3-none-any.whl
MarkupSafe==3.0.2
matplotlib==3.10.0
mindformers==1.3.2
mindpet==1.0.4
mindsponge_ascend @ file:///nfs/grp/gyqlab/konglp/workspace/multimer_grasp_v11_0430_bac/multimer_grasp_v11_0430_bac/mindscience/MindSPONGE/output/mindsponge_ascend-1.0.0rc2-py3-none-any.whl#sha256=83c220d14ec130a8179def65221617164b66e57cda8d620be46eb80270ba44a9
mindspore @ https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.5.0/MindSpore/unified/aarch64/mindspore-2.5.0-cp310-cp310-linux_aarch64.whl#sha256=1116fd666a059f0480deccd6af04f5e9fe9c019fa88df24a51b0e0fe3c2e55da
ml_dtypes==0.5.1
mpmath==1.3.0
msadvisor @ file:///usr/local/Ascend/ascend-toolkit/8.0.RC2/tools/msadvisor/python/msadvisor-1.0.0-cp37-abi3-linux_aarch64.whl
msgpack==1.1.0
multidict==6.1.0
multiprocess==0.70.16
mygene==3.2.2
networkx==3.4.2
nltk==3.9.1
numexpr @ file:///croot/numexpr_1730215942651/work
numpy==1.23.4
op-compile-tool @ file:///root/selfgz130520532488/compiler/lib64/op_compile_tool-0.1.0-py3-none-any.whl
op-gen @ file:///usr/local/Ascend/ascend-toolkit/8.0.RC2/toolkit/tools/op_gen-0.1-py3-none-any.whl
op-test-frame @ file:///usr/local/Ascend/ascend-toolkit/8.0.RC2/toolkit/tools/op_test_frame-0.1-py3-none-any.whl
opc-tool @ file:///root/selfgz130520532488/compiler/lib64/opc_tool-0.1.0-py3-none-any.whl
opencv-python-headless==4.11.0.86
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1733203243479/work
pandas @ file:///croot/pandas_1732735105235/work/dist/pandas-2.2.3-cp310-cp310-linux_aarch64.whl#sha256=ce019667128a6de8bd8a2994b4bae9691713b9c98906420f2b7dedb0a993963a
pandas-flavor==0.6.0
pillow @ file:///croot/pillow_1734430599218/work
platformdirs==4.3.6
pooch==1.8.2
propcache==0.2.1
protobuf==3.19.1
psutil==6.1.1
pyarrow==12.0.1
pyarrow-hotfix==0.6
pyparsing==3.2.1
python-dateutil @ file:///croot/python-dateutil_1716495745266/work
pytz @ file:///croot/pytz_1713974315080/work
PyYAML==6.0.2
rdkit==2024.9.4
regex==2024.11.6
requests==2.32.3
rouge-chinese==1.0.3
safetensors @ file:///croot/safetensors_1732227620007/work
schedule-search @ file:///root/selfgz130520532488/compiler/lib64/schedule_search-0.1.0-py3-none-any.whl
scikit-learn==1.6.1
scipy==1.13.1
sentencepiece==0.2.0
setproctitle==1.3.4
six @ file:///tmp/build/80754af9/six_1644875935023/work
sniffio==1.3.1
sympy==1.13.3
te @ file:///root/selfgz130520532488/compiler/lib64/te-0.4.0-py3-none-any.whl
threadpoolctl==3.5.0
tiktoken==0.8.0
tokenizers==0.15.0
tornado==6.4.2
tqdm==4.67.1
typing_extensions==4.12.2
tzdata @ file:///croot/python-tzdata_1690578112552/work
urllib3==2.3.0
wcwidth==0.2.13
xarray==2024.7.0
xxhash==3.5.0
yarl==1.18.3

```
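Before launching, it can help to confirm that the `mindspore` wheel above actually reaches the Ascend backend from this conda environment. A minimal sanity check using MindSpore's built-in installation verifier (the `set_context` flag is illustrative, not taken from this repo):

```python
# Sanity-check the python310 environment: import MindSpore and run its
# built-in installation check against the Ascend backend.
import mindspore as ms

ms.set_context(device_target="Ascend")  # illustrative: select the NPU backend
ms.run_check()  # prints the installed version and runs a small test op
```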
#### mpirun version

```bash
mpirun (Open MPI) 4.1.2
```

## 2 Running inference

### 2.1 Multimer multi-card inference

```bash
bash infer_main_parallel.sh 0,1,2,3,4,5,6,7 8064 "./5JDS.pkl;;./step_8000.ckpt;1;1"
```

1. 0,1,2,3,4,5,6,7 specifies the device_id values to use (any available devices).
2. 8064 is the sequence length.
3. "./5JDS.pkl;;./step_8000.ckpt;1;1" is a semicolon-separated string of five input parameters: raw_feat, restr (may be empty, leaving two consecutive semicolons), ckpt_path, iter, and num_recycle. The string above therefore means:
    1. raw_feat="./5JDS.pkl"
    2. restr="None"
    3. ckpt_path="./step_8000.ckpt"
    4. iter=1
    5. num_recycle=1
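`infer_main_parallel.sh` forwards this quoted string to the Python entry point; how the script splits it internally is not shown here, but the five-field convention itself is easy to pin down. The helper below is hypothetical (not part of the repo) and only illustrates the field order described above:

```python
# Hypothetical helper: decomposes the "raw_feat;restr;ckpt_path;iter;num_recycle"
# argument string used by infer_main_parallel.sh.
def parse_run_args(arg_string: str) -> dict:
    raw_feat, restr, ckpt_path, n_iter, num_recycle = arg_string.split(";")
    return {
        "raw_feat": raw_feat,
        "restr": restr or None,  # empty field (";;") means no restraints
        "ckpt_path": ckpt_path,
        "iter": int(n_iter),
        "num_recycle": int(num_recycle),
    }

print(parse_run_args("./5JDS.pkl;;./step_8000.ckpt;1;1"))
# -> raw_feat='./5JDS.pkl', restr=None, ckpt_path='./step_8000.ckpt', iter=1, num_recycle=1
print(parse_run_args("./features.pkl;./restr_5perc.pkl;step_14000.ckpt;5;20"))
```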
```shell
# Result log; the pdb file is saved to ./compare_with_parallel/test4_8064_iter1_recycle10_graph_parallel.pdb
start recycle_cond
recycle 1 diff: 58.07871833571992
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 2 diff: 8.910383957501509
end recycle_cond: True
--------------------start----------------------
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-19:29:58.873.106 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 1, free memory : 33637916672, real free : 33598472192, not free : 39444480.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-19:30:19.577.569 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 2, free memory : 23910161920, real free : 23899144192, not free : 11017728.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-19:31:04.954.248 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 3, free memory : 20992928256, real free : 20967325696, not free : 25602560.
--------------------end------------------------
start recycle_cond
recycle 3 diff: 1.9637068183177169
end recycle_cond: True
--------------------start----------------------
[WARNING] DEVICE(1020081,fff400e05120,python):2025-03-03-19:47:05.102.537 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_vmm_adapter.cc:176] MmapDeviceMem] Mapped too much memory, physical_handle_size_ : 29696, max_size : 62277025792.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-19:47:09.315.342 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 4, free memory : 46579264000, real free : 46628077568, not free : 0.
--------------------end------------------------
start recycle_cond
recycle 4 diff: 1.3888285949764172
end recycle_cond: True
--------------------start----------------------
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:06:17.195.763 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 5, free memory : 41300996096, real free : 41305505792, not free : 0.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:08:13.866.614 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 6, free memory : 24988031488, real free : 24954011648, not free : 34019840.
--------------------end------------------------
start recycle_cond
recycle 5 diff: 10.066165713126406
end recycle_cond: True
--------------------start----------------------
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:25:22.280.240 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 7, free memory : 45445552128, real free : 45470449664, not free : 0.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:27:17.831.470 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 8, free memory : 24988032512, real free : 24956108800, not free : 31923712.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:29:01.096.451 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 9, free memory : 20207269376, real free : 20199768064, not free : 7501312.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:29:22.670.183 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 10, free memory : 20810024960, real free : 20791164928, not free : 18860032.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:29:24.540.909 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 11, free memory : 16676098048, real free : 16680747008, not free : 0.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:29:44.289.865 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 12, free memory : 20810024960, real free : 20803747840, not free : 6277120.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:29:46.178.452 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 13, free memory : 16676098048, real free : 16680747008, not free : 0.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:30:05.953.177 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 14, free memory : 20810024960, real free : 20799553536, not free : 10471424.
--------------------end------------------------
start recycle_cond
recycle 6 diff: 3.656605440009259
end recycle_cond: True
--------------------start----------------------
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:44:36.021.703 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 15, free memory : 43995280384, real free : 44025511936, not free : 0.
[WARNING] PRE_ACT(1020081,fff400e05120,python):2025-03-03-20:46:31.630.084 [mindspore/ccsrc/backend/common/mem_reuse/abstract_dynamic_mem_pool.cc:1036] FreeIdleMemsByEagerFree] Eager free count : 16, free memory : 25020546048, real free : 24991760384, not free : 28785664.
--------------------end------------------------
start recycle_cond
recycle 7 diff: 3.186314201691005
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 8 diff: 1.2131272309106085
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 9 diff: 0.9297680422422511
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
[WARNING] CORE(1020081,ffffa38ab020,python):2025-03-03-22:00:39.424.233 [mindspore/core/include/ir/base_tensor.h:452] data] Try to alloca a large memory, size is:8323596288
 ===================== pdb_path ==================== ./compare_with_parallel/test4_8064_iter1_recycle10_graph_parallel.pdb
Filter Restraints Iteration 1 =============================================
Breakage info ==========
Break number: 0, Max neighbour CA dist: 4.078125

Recall info=============
Stop iteration: RemoveThre,Converged,LastIter
Inference done!
time cost: 13111.61140203476
```
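The `Breakage info` block summarizes chain continuity of the written pdb (the repo computes it internally; see `utils_infer.py:filter_restraints` in the whitelist above). Purely to illustrate what the metric measures, a Biopython sketch over the output file might look like the following; Biopython is already in the environment, but the 5.0 Å break threshold is an assumption, not the repo's value:

```python
# Illustrative only: scan consecutive CA-CA distances within each chain of the
# output pdb, mirroring the "Break number" / "Max neighbour CA dist" report.
from Bio.PDB import PDBParser
import numpy as np

structure = PDBParser(QUIET=True).get_structure(
    "pred", "./compare_with_parallel/test4_8064_iter1_recycle10_graph_parallel.pdb")

dists = []
for chain in structure[0]:                       # first model
    cas = [res["CA"].coord for res in chain if "CA" in res]
    dists += [float(np.linalg.norm(a - b)) for a, b in zip(cas[:-1], cas[1:])]

breaks = [d for d in dists if d > 5.0]           # assumed breakage threshold
print(f"Break number: {len(breaks)}, Max neighbour CA dist: {max(dists):.4f}")
```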
### 2.2 Grasp_7R94 multi-card inference

```bash
# The sequence in 7R94.pkl is 3700+ residues long, so it is padded to 4096.
bash infer_main_parallel.sh 0,1,2,3,4,5,6,7 4096 "./features.pkl;./restr_5perc.pkl;step_14000.ckpt;5;20"
```

1. 0,1,2,3,4,5,6,7 specifies the device_id values to use (any available devices).
2. 4096 is the (padded) sequence length.
3. "./features.pkl;./restr_5perc.pkl;step_14000.ckpt;5;20" is the same semicolon-separated string of five input parameters: raw_feat, restr (may be empty, leaving two consecutive semicolons), ckpt_path, iter, and num_recycle. The string above therefore means:
    1. raw_feat="./features.pkl"
    2. restr="./restr_5perc.pkl"
    3. ckpt_path="./step_14000.ckpt"
    4. iter=5
    5. num_recycle=20

```shell
# Result log with seed=9
At least 38 restraints will be used in the final iteration
iter is 5
[WARNING] CORE(2128692,ffff907d5020,python):2025-03-10-10:06:19.623.866 [mindspore/core/include/ir/base_tensor.h:85] NewData] Try to alloca a large memory, size is:4294967296
num_recycle is 20
msa_feat_sum 3841181.6750109335
start recycle_cond
recycle 0 diff: 0.0001
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 1 diff: 78.62324050630868
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 2 diff: 25.586854637566837
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 3 diff: 8.839741836685704
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 4 diff: 2.436669909107999
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 5 diff: 3.358055246987672
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 6 diff: 4.751788874477254
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 7 diff: 2.8444712162684724
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 8 diff: 1.592084565769719
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 9 diff: 0.8363213934548326
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 10 diff: 0.6078216719909308
end recycle_cond: True
--------------------start----------------------
--------------------end------------------------
start recycle_cond
recycle 11 diff: 0.4431440625287018
end recycle_cond: False
early stop: 11
 ===================== pdb_path ==================== ./compare_with_parallel/test6_4096_iter1_recycle20_graph_parallel.pdb
Filter Restraints Iteration 1 =============================================
inter-residue restraints: 189(189 inter-chain + 0 intra-chain)
Inter-chain restraints
Included! Satisfied! A19/conf84.81/nbdist_avg_ca3.88<==>F477/conf53.87/nbdist_avg_ca4.15/dist_cb18.94, range: 0-25.0, rm_score 0, rm_thre 0.0
Included! Satisfied! A20/conf81.57/nbdist_avg_ca3.73<==>F481/conf49.65/nbdist_avg_ca3.65/dist_cb22.52, range: 0-25.0, rm_score 0, rm_thre 0.0
Included! Satisfied!
A21/conf79.47/nbdist_avg_ca3.42<==>F611/conf62.57/nbdist_avg_ca3.73/dist_cb22.17, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A26/conf77.43/nbdist_avg_ca3.93<==>F477/conf53.87/nbdist_avg_ca4.15/dist_cb21.88, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A43/conf63.87/nbdist_avg_ca3.75<==>C370/conf78.42/nbdist_avg_ca3.93/dist_cb17.47, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A52/conf53.33/nbdist_avg_ca3.96<==>B271/conf74.52/nbdist_avg_ca3.88/dist_cb24.81, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A52/conf53.33/nbdist_avg_ca3.96<==>F466/conf68.76/nbdist_avg_ca3.82/dist_cb15.82, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A52/conf53.33/nbdist_avg_ca3.96<==>F473/conf70.79/nbdist_avg_ca3.82/dist_cb19.02, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A54/conf65.74/nbdist_avg_ca3.89<==>F467/conf66.28/nbdist_avg_ca3.92/dist_cb15.35, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A58/conf72.73/nbdist_avg_ca3.98<==>C293/conf77.78/nbdist_avg_ca3.86/dist_cb18.69, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A58/conf72.73/nbdist_avg_ca3.98<==>F477/conf53.87/nbdist_avg_ca4.15/dist_cb17.69, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A61/conf74.27/nbdist_avg_ca4.30<==>C293/conf77.78/nbdist_avg_ca3.86/dist_cb15.40, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A62/conf53.55/nbdist_avg_ca4.80<==>F486/conf71.07/nbdist_avg_ca3.86/dist_cb15.84, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A62/conf53.55/nbdist_avg_ca4.80<==>F495/conf55.37/nbdist_avg_ca3.75/dist_cb18.12, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A66/conf63.41/nbdist_avg_ca3.98<==>F459/conf64.33/nbdist_avg_ca3.70/dist_cb23.83, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A68/conf76.13/nbdist_avg_ca3.83<==>B322/conf89.01/nbdist_avg_ca3.85/dist_cb18.78, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A69/conf75.64/nbdist_avg_ca3.69<==>B283/conf87.03/nbdist_avg_ca3.81/dist_cb21.70, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A70/conf75.34/nbdist_avg_ca3.69<==>F484/conf63.43/nbdist_avg_ca3.91/dist_cb24.56, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A79/conf82.90/nbdist_avg_ca4.00<==>B291/conf77.93/nbdist_avg_ca3.84/dist_cb24.84, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A79/conf82.90/nbdist_avg_ca4.00<==>B327/conf84.36/nbdist_avg_ca3.76/dist_cb22.64, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A91/conf75.50/nbdist_avg_ca3.83<==>F502/conf66.05/nbdist_avg_ca3.92/dist_cb24.20, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! A93/conf70.61/nbdist_avg_ca3.85<==>F425/conf74.94/nbdist_avg_ca3.86/dist_cb26.58, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A105/conf83.63/nbdist_avg_ca3.86<==>F611/conf62.57/nbdist_avg_ca3.73/dist_cb24.16, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A132/conf67.23/nbdist_avg_ca4.91<==>F477/conf53.87/nbdist_avg_ca4.15/dist_cb15.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A132/conf67.23/nbdist_avg_ca4.91<==>F611/conf62.57/nbdist_avg_ca3.73/dist_cb20.30, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A134/conf87.77/nbdist_avg_ca4.09<==>F477/conf53.87/nbdist_avg_ca4.15/dist_cb20.44, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! 
A147/conf83.30/nbdist_avg_ca3.86<==>G410/conf72.35/nbdist_avg_ca3.89/dist_cb106.69, range: 0-25.0, rm_score 76.6875, rm_thre 0.0 +Included! Satisfied! A181/conf80.56/nbdist_avg_ca3.93<==>B283/conf87.03/nbdist_avg_ca3.81/dist_cb18.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A189/conf86.43/nbdist_avg_ca3.94<==>B267/conf78.14/nbdist_avg_ca3.98/dist_cb22.73, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A190/conf87.19/nbdist_avg_ca3.83<==>B372/conf78.83/nbdist_avg_ca3.83/dist_cb22.34, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A193/conf82.82/nbdist_avg_ca3.83<==>B114/conf77.64/nbdist_avg_ca3.83/dist_cb12.43, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A196/conf75.25/nbdist_avg_ca3.89<==>B186/conf86.98/nbdist_avg_ca3.80/dist_cb18.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A196/conf75.25/nbdist_avg_ca3.89<==>B372/conf78.83/nbdist_avg_ca3.83/dist_cb17.92, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A200/conf75.59/nbdist_avg_ca3.82<==>B16/conf81.36/nbdist_avg_ca3.79/dist_cb19.66, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! A201/conf72.03/nbdist_avg_ca3.83<==>C300/conf85.81/nbdist_avg_ca3.82/dist_cb30.20, range: 0-25.0, rm_score 0.203125, rm_thre 0.0 +Included! Satisfied! A204/conf75.94/nbdist_avg_ca3.92<==>B183/conf88.40/nbdist_avg_ca3.79/dist_cb18.00, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A205/conf76.10/nbdist_avg_ca3.87<==>B282/conf87.93/nbdist_avg_ca3.92/dist_cb14.29, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A210/conf82.53/nbdist_avg_ca3.86<==>C320/conf85.81/nbdist_avg_ca3.85/dist_cb24.11, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A214/conf86.91/nbdist_avg_ca3.85<==>B370/conf78.22/nbdist_avg_ca4.05/dist_cb24.73, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! A217/conf86.06/nbdist_avg_ca3.89<==>F596/conf76.15/nbdist_avg_ca3.94/dist_cb61.78, range: 0-25.0, rm_score 31.78125, rm_thre 0.0 +Excluded! Violated! A233/conf75.74/nbdist_avg_ca3.95<==>E70/conf85.95/nbdist_avg_ca3.74/dist_cb121.50, range: 0-25.0, rm_score 91.5, rm_thre 0.0 +Included! Satisfied! A235/conf74.96/nbdist_avg_ca3.81<==>B366/conf82.62/nbdist_avg_ca3.78/dist_cb21.67, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! A237/conf78.12/nbdist_avg_ca3.82<==>D328/conf79.35/nbdist_avg_ca3.69/dist_cb70.62, range: 0-25.0, rm_score 40.625, rm_thre 0.0 +Included! Satisfied! A246/conf71.49/nbdist_avg_ca4.10<==>C280/conf89.75/nbdist_avg_ca3.84/dist_cb15.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A246/conf71.49/nbdist_avg_ca4.10<==>C326/conf76.01/nbdist_avg_ca3.85/dist_cb9.06, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A252/conf84.47/nbdist_avg_ca3.88<==>B122/conf86.12/nbdist_avg_ca3.78/dist_cb23.22, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A262/conf91.78/nbdist_avg_ca3.79<==>B284/conf87.21/nbdist_avg_ca3.82/dist_cb22.34, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! A286/conf86.61/nbdist_avg_ca3.84<==>H500/conf76.92/nbdist_avg_ca3.83/dist_cb81.81, range: 0-25.0, rm_score 51.8125, rm_thre 0.0 +Excluded! Violated! A325/conf88.42/nbdist_avg_ca3.95<==>G503/conf70.46/nbdist_avg_ca3.86/dist_cb113.06, range: 0-25.0, rm_score 83.0625, rm_thre 0.0 +Excluded! Violated! A339/conf83.98/nbdist_avg_ca3.87<==>B263/conf90.52/nbdist_avg_ca3.86/dist_cb45.69, range: 0-25.0, rm_score 15.6875, rm_thre 0.0 +Included! Satisfied! 
A352/conf74.21/nbdist_avg_ca4.02<==>F610/conf65.76/nbdist_avg_ca3.79/dist_cb20.67, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A360/conf75.39/nbdist_avg_ca3.98<==>F612/conf50.48/nbdist_avg_ca3.85/dist_cb22.97, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! A361/conf83.49/nbdist_avg_ca3.86<==>B126/conf88.13/nbdist_avg_ca3.90/dist_cb77.44, range: 0-25.0, rm_score 47.4375, rm_thre 0.0 +Excluded! Violated! B7/conf58.90/nbdist_avg_ca3.66<==>C210/conf81.56/nbdist_avg_ca3.85/dist_cb76.56, range: 0-25.0, rm_score 46.5625, rm_thre 0.0 +Included! Satisfied! B7/conf58.90/nbdist_avg_ca3.66<==>H529/conf80.11/nbdist_avg_ca3.87/dist_cb18.12, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! B14/conf85.85/nbdist_avg_ca3.69<==>F574/conf76.56/nbdist_avg_ca3.90/dist_cb94.94, range: 0-25.0, rm_score 64.9375, rm_thre 0.0 +Excluded! Violated! B19/conf88.88/nbdist_avg_ca3.73<==>H420/conf74.43/nbdist_avg_ca3.89/dist_cb30.45, range: 0-25.0, rm_score 0.453125, rm_thre 0.0 +Included! Satisfied! B45/conf61.83/nbdist_avg_ca3.93<==>D338/conf80.29/nbdist_avg_ca3.88/dist_cb23.92, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B47/conf51.34/nbdist_avg_ca3.76<==>H458/conf69.88/nbdist_avg_ca3.80/dist_cb12.79, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B49/conf50.61/nbdist_avg_ca3.65<==>H452/conf53.38/nbdist_avg_ca3.63/dist_cb21.58, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B52/conf57.78/nbdist_avg_ca3.89<==>H439/conf73.78/nbdist_avg_ca3.73/dist_cb24.02, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B52/conf57.78/nbdist_avg_ca3.89<==>H467/conf76.38/nbdist_avg_ca4.03/dist_cb11.34, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B52/conf57.78/nbdist_avg_ca3.89<==>H473/conf73.81/nbdist_avg_ca3.76/dist_cb18.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B54/conf69.12/nbdist_avg_ca3.93<==>C268/conf76.98/nbdist_avg_ca4.00/dist_cb24.80, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B54/conf69.12/nbdist_avg_ca3.93<==>H478/conf64.28/nbdist_avg_ca3.63/dist_cb14.81, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B54/conf69.12/nbdist_avg_ca3.93<==>H499/conf72.36/nbdist_avg_ca3.89/dist_cb24.86, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B60/conf77.50/nbdist_avg_ca3.79<==>D293/conf76.74/nbdist_avg_ca3.88/dist_cb18.08, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B62/conf73.11/nbdist_avg_ca3.88<==>D141/conf79.74/nbdist_avg_ca3.89/dist_cb23.58, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B68/conf74.15/nbdist_avg_ca3.98<==>C284/conf84.71/nbdist_avg_ca3.81/dist_cb22.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B68/conf74.15/nbdist_avg_ca3.98<==>C314/conf87.84/nbdist_avg_ca3.79/dist_cb23.83, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B68/conf74.15/nbdist_avg_ca3.98<==>D290/conf74.07/nbdist_avg_ca4.07/dist_cb13.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B68/conf74.15/nbdist_avg_ca3.98<==>D292/conf78.18/nbdist_avg_ca3.88/dist_cb19.12, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! B87/conf87.66/nbdist_avg_ca3.89<==>D296/conf80.75/nbdist_avg_ca3.85/dist_cb33.56, range: 0-25.0, rm_score 3.5625, rm_thre 0.0 +Excluded! Violated! B88/conf88.88/nbdist_avg_ca3.79<==>G569/conf78.07/nbdist_avg_ca3.89/dist_cb103.62, range: 0-25.0, rm_score 73.625, rm_thre 0.0 +Excluded! Violated! 
B91/conf85.51/nbdist_avg_ca3.90<==>E255/conf83.11/nbdist_avg_ca3.86/dist_cb98.44, range: 0-25.0, rm_score 68.4375, rm_thre 0.0 +Included! Violated! B92/conf79.02/nbdist_avg_ca3.79<==>H601/conf74.56/nbdist_avg_ca3.83/dist_cb29.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B102/conf68.78/nbdist_avg_ca4.24<==>H494/conf74.37/nbdist_avg_ca3.74/dist_cb24.25, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! B120/conf87.94/nbdist_avg_ca3.83<==>E115/conf77.85/nbdist_avg_ca3.87/dist_cb85.06, range: 0-25.0, rm_score 55.0625, rm_thre 0.0 +Included! Satisfied! B190/conf89.41/nbdist_avg_ca3.95<==>D290/conf74.07/nbdist_avg_ca4.07/dist_cb23.09, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B193/conf83.12/nbdist_avg_ca3.95<==>C111/conf82.60/nbdist_avg_ca3.83/dist_cb11.49, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B193/conf83.12/nbdist_avg_ca3.95<==>C371/conf76.63/nbdist_avg_ca3.75/dist_cb23.30, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! B200/conf74.76/nbdist_avg_ca3.77<==>C134/conf89.49/nbdist_avg_ca3.88/dist_cb27.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B200/conf74.76/nbdist_avg_ca3.77<==>D295/conf83.60/nbdist_avg_ca3.89/dist_cb21.88, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B201/conf69.31/nbdist_avg_ca3.92<==>C83/conf83.84/nbdist_avg_ca3.81/dist_cb21.28, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B204/conf74.54/nbdist_avg_ca3.93<==>C322/conf84.78/nbdist_avg_ca3.91/dist_cb21.41, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B214/conf86.86/nbdist_avg_ca3.82<==>D328/conf79.35/nbdist_avg_ca3.69/dist_cb23.62, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B217/conf87.16/nbdist_avg_ca3.87<==>D325/conf75.44/nbdist_avg_ca3.97/dist_cb22.30, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! B241/conf84.92/nbdist_avg_ca3.79<==>C373/conf73.76/nbdist_avg_ca3.92/dist_cb25.02, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B251/conf83.94/nbdist_avg_ca3.87<==>D323/conf85.09/nbdist_avg_ca3.93/dist_cb19.97, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B252/conf86.76/nbdist_avg_ca3.88<==>D325/conf75.44/nbdist_avg_ca3.97/dist_cb21.02, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B254/conf76.62/nbdist_avg_ca4.01<==>C79/conf77.10/nbdist_avg_ca3.96/dist_cb23.44, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! B330/conf86.88/nbdist_avg_ca3.81<==>E291/conf75.38/nbdist_avg_ca3.89/dist_cb82.25, range: 0-25.0, rm_score 52.25, rm_thre 0.0 +Included! Violated! B339/conf81.73/nbdist_avg_ca3.96<==>H479/conf61.07/nbdist_avg_ca3.86/dist_cb26.19, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! B346/conf87.67/nbdist_avg_ca3.82<==>F535/conf85.59/nbdist_avg_ca3.98/dist_cb107.12, range: 0-25.0, rm_score 77.125, rm_thre 0.0 +Excluded! Violated! B360/conf88.00/nbdist_avg_ca3.77<==>F503/conf67.55/nbdist_avg_ca4.02/dist_cb83.88, range: 0-25.0, rm_score 53.875, rm_thre 0.0 +Included! Violated! C8/conf67.08/nbdist_avg_ca3.72<==>G601/conf71.87/nbdist_avg_ca3.81/dist_cb27.84, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C11/conf79.07/nbdist_avg_ca3.97<==>F443/conf56.19/nbdist_avg_ca4.05/dist_cb20.73, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C20/conf84.09/nbdist_avg_ca3.78<==>F455/conf62.60/nbdist_avg_ca3.65/dist_cb22.31, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! 
C24/conf72.05/nbdist_avg_ca3.83<==>G547/conf62.33/nbdist_avg_ca3.91/dist_cb17.23, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C27/conf65.95/nbdist_avg_ca3.94<==>G492/conf62.87/nbdist_avg_ca4.04/dist_cb22.66, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C43/conf64.67/nbdist_avg_ca3.81<==>E296/conf79.67/nbdist_avg_ca4.32/dist_cb24.61, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C45/conf63.25/nbdist_avg_ca3.84<==>E360/conf84.46/nbdist_avg_ca3.74/dist_cb18.95, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C49/conf48.10/nbdist_avg_ca3.64<==>G452/conf52.50/nbdist_avg_ca3.77/dist_cb21.72, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C52/conf59.04/nbdist_avg_ca3.68<==>E351/conf69.40/nbdist_avg_ca3.92/dist_cb20.34, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C54/conf66.82/nbdist_avg_ca3.94<==>G475/conf71.59/nbdist_avg_ca3.88/dist_cb19.28, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C61/conf67.90/nbdist_avg_ca4.17<==>G492/conf62.87/nbdist_avg_ca4.04/dist_cb16.25, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C62/conf57.11/nbdist_avg_ca4.16<==>E325/conf73.76/nbdist_avg_ca3.91/dist_cb24.34, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C62/conf57.11/nbdist_avg_ca4.16<==>G460/conf68.64/nbdist_avg_ca4.05/dist_cb17.66, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C63/conf68.08/nbdist_avg_ca4.49<==>G421/conf68.24/nbdist_avg_ca4.13/dist_cb23.44, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C63/conf68.08/nbdist_avg_ca4.49<==>G500/conf62.57/nbdist_avg_ca4.05/dist_cb22.34, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C68/conf74.76/nbdist_avg_ca3.92<==>E283/conf85.22/nbdist_avg_ca3.80/dist_cb22.33, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C68/conf74.76/nbdist_avg_ca3.92<==>E287/conf75.37/nbdist_avg_ca3.78/dist_cb16.05, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C71/conf83.70/nbdist_avg_ca3.84<==>D285/conf76.41/nbdist_avg_ca3.92/dist_cb16.62, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C73/conf78.34/nbdist_avg_ca3.80<==>D279/conf85.08/nbdist_avg_ca3.93/dist_cb24.09, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C86/conf77.82/nbdist_avg_ca4.12<==>D274/conf79.42/nbdist_avg_ca3.78/dist_cb24.64, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C86/conf77.82/nbdist_avg_ca4.12<==>G501/conf70.39/nbdist_avg_ca4.21/dist_cb24.92, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C90/conf75.43/nbdist_avg_ca4.14<==>G492/conf62.87/nbdist_avg_ca4.04/dist_cb14.47, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C93/conf68.85/nbdist_avg_ca4.14<==>G438/conf73.57/nbdist_avg_ca3.87/dist_cb20.47, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C93/conf68.85/nbdist_avg_ca4.14<==>G465/conf74.53/nbdist_avg_ca3.76/dist_cb22.69, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! C125/conf88.82/nbdist_avg_ca3.85<==>D346/conf79.41/nbdist_avg_ca3.77/dist_cb66.00, range: 0-25.0, rm_score 36.0, rm_thre 0.0 +Included! Satisfied! C129/conf75.29/nbdist_avg_ca4.15<==>G506/conf54.08/nbdist_avg_ca3.95/dist_cb24.53, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C196/conf77.05/nbdist_avg_ca3.91<==>D141/conf79.74/nbdist_avg_ca3.89/dist_cb19.44, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! 
C196/conf77.05/nbdist_avg_ca3.91<==>D370/conf78.90/nbdist_avg_ca4.03/dist_cb19.08, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C196/conf77.05/nbdist_avg_ca3.91<==>E281/conf86.42/nbdist_avg_ca3.93/dist_cb24.95, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C200/conf75.56/nbdist_avg_ca3.79<==>E283/conf85.22/nbdist_avg_ca3.80/dist_cb22.61, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C201/conf73.35/nbdist_avg_ca3.80<==>D192/conf86.92/nbdist_avg_ca3.78/dist_cb23.95, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C202/conf78.49/nbdist_avg_ca3.92<==>D305/conf86.34/nbdist_avg_ca3.81/dist_cb23.73, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C204/conf75.52/nbdist_avg_ca3.94<==>D305/conf86.34/nbdist_avg_ca3.81/dist_cb22.05, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C212/conf87.61/nbdist_avg_ca3.89<==>D273/conf76.92/nbdist_avg_ca3.67/dist_cb19.06, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C236/conf77.65/nbdist_avg_ca3.81<==>E326/conf73.64/nbdist_avg_ca3.87/dist_cb23.20, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C237/conf74.30/nbdist_avg_ca3.70<==>D115/conf77.15/nbdist_avg_ca3.93/dist_cb19.88, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C241/conf82.67/nbdist_avg_ca3.77<==>D77/conf78.19/nbdist_avg_ca3.85/dist_cb22.03, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! C241/conf82.67/nbdist_avg_ca3.77<==>H490/conf74.14/nbdist_avg_ca3.77/dist_cb59.69, range: 0-25.0, rm_score 29.6875, rm_thre 0.0 +Included! Satisfied! C249/conf82.26/nbdist_avg_ca3.87<==>E317/conf86.42/nbdist_avg_ca4.20/dist_cb22.06, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C262/conf89.17/nbdist_avg_ca3.84<==>D180/conf81.69/nbdist_avg_ca3.90/dist_cb23.95, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C268/conf76.98/nbdist_avg_ca4.00<==>D281/conf87.23/nbdist_avg_ca3.99/dist_cb24.09, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! C286/conf77.84/nbdist_avg_ca3.84<==>D139/conf83.41/nbdist_avg_ca3.80/dist_cb47.72, range: 0-25.0, rm_score 17.71875, rm_thre 0.0 +Included! Satisfied! C305/conf86.49/nbdist_avg_ca3.86<==>D286/conf78.16/nbdist_avg_ca3.88/dist_cb20.84, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C346/conf78.58/nbdist_avg_ca3.84<==>F442/conf66.81/nbdist_avg_ca4.13/dist_cb22.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C346/conf78.58/nbdist_avg_ca3.84<==>F452/conf57.23/nbdist_avg_ca3.63/dist_cb14.21, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D28/conf79.20/nbdist_avg_ca3.88<==>H458/conf69.88/nbdist_avg_ca3.80/dist_cb18.19, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D54/conf80.45/nbdist_avg_ca3.87<==>E265/conf88.52/nbdist_avg_ca3.84/dist_cb22.08, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D70/conf82.33/nbdist_avg_ca3.84<==>E225/conf86.10/nbdist_avg_ca3.92/dist_cb23.97, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D71/conf83.73/nbdist_avg_ca3.84<==>E267/conf79.80/nbdist_avg_ca3.96/dist_cb20.02, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D142/conf76.72/nbdist_avg_ca3.84<==>H452/conf53.38/nbdist_avg_ca3.63/dist_cb16.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D145/conf75.30/nbdist_avg_ca3.71<==>H425/conf75.49/nbdist_avg_ca3.99/dist_cb18.91, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! 
D182/conf81.66/nbdist_avg_ca3.81<==>E273/conf75.15/nbdist_avg_ca3.73/dist_cb25.05, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D193/conf83.74/nbdist_avg_ca3.83<==>E109/conf81.77/nbdist_avg_ca3.77/dist_cb17.14, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D200/conf75.96/nbdist_avg_ca3.78<==>E74/conf80.61/nbdist_avg_ca3.85/dist_cb14.73, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D200/conf75.96/nbdist_avg_ca3.78<==>E374/conf73.09/nbdist_avg_ca4.02/dist_cb23.95, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D202/conf76.71/nbdist_avg_ca3.84<==>E280/conf87.67/nbdist_avg_ca3.83/dist_cb21.30, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D204/conf75.64/nbdist_avg_ca3.87<==>E272/conf74.24/nbdist_avg_ca3.79/dist_cb4.79, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D210/conf81.57/nbdist_avg_ca3.76<==>E273/conf75.15/nbdist_avg_ca3.73/dist_cb16.61, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D233/conf76.43/nbdist_avg_ca3.86<==>E364/conf82.98/nbdist_avg_ca3.86/dist_cb17.47, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D244/conf84.59/nbdist_avg_ca3.93<==>E79/conf79.32/nbdist_avg_ca3.84/dist_cb19.44, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D250/conf84.12/nbdist_avg_ca3.88<==>E16/conf81.63/nbdist_avg_ca3.83/dist_cb23.56, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D251/conf82.11/nbdist_avg_ca3.86<==>E371/conf76.85/nbdist_avg_ca3.82/dist_cb24.19, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! D251/conf82.11/nbdist_avg_ca3.86<==>F428/conf78.91/nbdist_avg_ca3.75/dist_cb99.62, range: 0-25.0, rm_score 69.625, rm_thre 0.0 +Included! Satisfied! D262/conf89.81/nbdist_avg_ca3.85<==>E288/conf74.60/nbdist_avg_ca3.76/dist_cb18.12, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! D265/conf89.34/nbdist_avg_ca3.83<==>E235/conf79.06/nbdist_avg_ca3.91/dist_cb60.03, range: 0-25.0, rm_score 30.03125, rm_thre 0.0 +Included! Satisfied! D272/conf72.62/nbdist_avg_ca3.85<==>E281/conf86.42/nbdist_avg_ca3.93/dist_cb22.73, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! D306/conf87.30/nbdist_avg_ca3.88<==>H588/conf88.44/nbdist_avg_ca3.82/dist_cb69.69, range: 0-25.0, rm_score 39.6875, rm_thre 0.0 +Included! Satisfied! D330/conf77.42/nbdist_avg_ca3.94<==>H487/conf76.09/nbdist_avg_ca3.94/dist_cb20.88, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! D339/conf79.66/nbdist_avg_ca3.87<==>G576/conf66.50/nbdist_avg_ca3.74/dist_cb96.81, range: 0-25.0, rm_score 66.8125, rm_thre 0.0 +Included! Satisfied! D346/conf79.41/nbdist_avg_ca3.77<==>H444/conf74.18/nbdist_avg_ca3.84/dist_cb19.41, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D352/conf66.97/nbdist_avg_ca3.96<==>H490/conf74.14/nbdist_avg_ca3.77/dist_cb24.45, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D353/conf62.94/nbdist_avg_ca4.10<==>H484/conf69.82/nbdist_avg_ca3.83/dist_cb14.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E7/conf60.42/nbdist_avg_ca3.71<==>G450/conf54.53/nbdist_avg_ca3.64/dist_cb20.50, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! E52/conf77.93/nbdist_avg_ca3.71<==>H571/conf82.16/nbdist_avg_ca3.83/dist_cb134.25, range: 0-25.0, rm_score 104.25, rm_thre 0.0 +Included! Satisfied! E108/conf83.29/nbdist_avg_ca3.79<==>G456/conf61.81/nbdist_avg_ca3.92/dist_cb21.66, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! 
E132/conf85.96/nbdist_avg_ca3.81<==>G443/conf71.18/nbdist_avg_ca3.85/dist_cb19.56, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! E143/conf76.83/nbdist_avg_ca3.84<==>G496/conf65.90/nbdist_avg_ca3.99/dist_cb34.03, range: 0-25.0, rm_score 4.03125, rm_thre 0.0 +Included! Satisfied! E144/conf74.61/nbdist_avg_ca3.70<==>G459/conf66.24/nbdist_avg_ca3.92/dist_cb16.27, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! E146/conf72.00/nbdist_avg_ca3.81<==>F465/conf75.67/nbdist_avg_ca3.70/dist_cb73.38, range: 0-25.0, rm_score 43.375, rm_thre 0.0 +Included! Satisfied! E147/conf70.71/nbdist_avg_ca3.84<==>G460/conf68.64/nbdist_avg_ca4.05/dist_cb13.85, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! E217/conf84.46/nbdist_avg_ca3.81<==>G571/conf75.30/nbdist_avg_ca3.84/dist_cb84.31, range: 0-25.0, rm_score 54.3125, rm_thre 0.0 +Excluded! Violated! E230/conf88.39/nbdist_avg_ca3.89<==>F613/conf64.72/nbdist_avg_ca3.94/dist_cb113.06, range: 0-25.0, rm_score 83.0625, rm_thre 0.0 +Excluded! Violated! E237/conf78.58/nbdist_avg_ca3.80<==>F502/conf66.05/nbdist_avg_ca3.92/dist_cb110.81, range: 0-25.0, rm_score 80.8125, rm_thre 0.0 +Excluded! Violated! E273/conf75.15/nbdist_avg_ca3.73<==>H431/conf83.45/nbdist_avg_ca3.83/dist_cb70.62, range: 0-25.0, rm_score 40.625, rm_thre 0.0 +Excluded! Violated! E278/conf85.06/nbdist_avg_ca3.87<==>G613/conf70.01/nbdist_avg_ca3.71/dist_cb57.28, range: 0-25.0, rm_score 27.28125, rm_thre 0.0 +Excluded! Violated! E288/conf74.60/nbdist_avg_ca3.76<==>F611/conf62.57/nbdist_avg_ca3.73/dist_cb87.44, range: 0-25.0, rm_score 57.4375, rm_thre 0.0 +Included! Satisfied! E294/conf72.93/nbdist_avg_ca4.52<==>G450/conf54.53/nbdist_avg_ca3.64/dist_cb23.55, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E297/conf66.15/nbdist_avg_ca4.30<==>G420/conf64.45/nbdist_avg_ca4.00/dist_cb20.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! E316/conf83.99/nbdist_avg_ca4.24<==>G450/conf54.53/nbdist_avg_ca3.64/dist_cb28.62, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E330/conf73.99/nbdist_avg_ca4.05<==>G457/conf65.40/nbdist_avg_ca3.92/dist_cb19.36, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! E338/conf80.62/nbdist_avg_ca3.86<==>F611/conf62.57/nbdist_avg_ca3.73/dist_cb111.31, range: 0-25.0, rm_score 81.3125, rm_thre 0.0 +Included! Satisfied! E353/conf65.33/nbdist_avg_ca4.11<==>G472/conf73.60/nbdist_avg_ca3.80/dist_cb14.32, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E353/conf65.33/nbdist_avg_ca4.11<==>G494/conf69.50/nbdist_avg_ca4.03/dist_cb21.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Excluded! Violated! 
F499/conf71.74/nbdist_avg_ca3.83<==>G578/conf66.45/nbdist_avg_ca3.66/dist_cb43.22, range: 0-25.0, rm_score 13.21875, rm_thre 0.0 +>>>>> Total 189: 152 included, 144 satisfied +Breakage info ========== +Break number: 2, Max neighbour CA dist: 5.6640625 + +Recall info============= +interchain (w 1): recall 0.7619047618644494, recall weighted by confidence: 0.7512976084061713 +[WARNING] CORE(2128692,ffff907d5020,python):2025-03-10-11:16:17.647.697 [mindspore/core/include/ir/base_tensor.h:85] NewData] Try to alloca a large memory, size is:4294967296 +num_recycle is 20 +start recycle_cond +recycle 0 diff: 0.0001 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 1 diff: 84.20224933251642 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 2 diff: 8.531760285440317 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 3 diff: 2.998889868861365 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 4 diff: 1.8990718744742177 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 5 diff: 1.5802629971343825 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 6 diff: 1.2450133119931146 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 7 diff: 1.0525802643577602 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 8 diff: 1.0274764101442193 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 9 diff: 0.9284488014707725 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 10 diff: 0.5978580326863305 +end recycle_cond: True +--------------------start---------------------- +--------------------end------------------------ +start recycle_cond +recycle 11 diff: 0.47195285231080886 +end recycle_cond: False +early stop: 11 + ===================== pdb_path ==================== ./compare_with_parallel/test6_4096_iter2_recycle20_graph_parallel.pdb +Filter Restraints Iteration 2 ============================================= +inter-residue restraints: 152(152 inter-chain + 0 intra-chain) +Inter-chain restraints +Included! Satisfied! A19/conf86.28/nbdist_avg_ca3.83<==>F477/conf57.81/nbdist_avg_ca3.99/dist_cb18.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A20/conf82.55/nbdist_avg_ca3.72<==>F481/conf53.03/nbdist_avg_ca3.64/dist_cb22.64, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A21/conf84.59/nbdist_avg_ca3.51<==>F611/conf69.38/nbdist_avg_ca3.91/dist_cb22.17, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A26/conf78.18/nbdist_avg_ca3.88<==>F477/conf57.81/nbdist_avg_ca3.99/dist_cb22.12, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! 
A43/conf67.06/nbdist_avg_ca3.73<==>C370/conf79.39/nbdist_avg_ca3.91/dist_cb17.56, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A52/conf59.21/nbdist_avg_ca3.85<==>B271/conf75.23/nbdist_avg_ca3.85/dist_cb24.14, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A52/conf59.21/nbdist_avg_ca3.85<==>F466/conf74.17/nbdist_avg_ca3.83/dist_cb15.97, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A52/conf59.21/nbdist_avg_ca3.85<==>F473/conf72.32/nbdist_avg_ca3.80/dist_cb18.81, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A54/conf72.29/nbdist_avg_ca3.86<==>F467/conf73.62/nbdist_avg_ca3.91/dist_cb15.64, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A58/conf75.01/nbdist_avg_ca3.84<==>C293/conf80.21/nbdist_avg_ca3.84/dist_cb18.14, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A58/conf75.01/nbdist_avg_ca3.84<==>F477/conf57.81/nbdist_avg_ca3.99/dist_cb18.09, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A61/conf74.92/nbdist_avg_ca4.04<==>C293/conf80.21/nbdist_avg_ca3.84/dist_cb15.23, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A62/conf66.72/nbdist_avg_ca4.37<==>F486/conf73.79/nbdist_avg_ca3.82/dist_cb16.31, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A62/conf66.72/nbdist_avg_ca4.37<==>F495/conf68.56/nbdist_avg_ca3.74/dist_cb18.72, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A66/conf69.73/nbdist_avg_ca3.96<==>F459/conf71.05/nbdist_avg_ca3.74/dist_cb23.50, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A68/conf77.25/nbdist_avg_ca3.81<==>B322/conf90.03/nbdist_avg_ca3.82/dist_cb18.75, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A69/conf76.07/nbdist_avg_ca3.71<==>B283/conf87.63/nbdist_avg_ca3.79/dist_cb21.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A70/conf77.38/nbdist_avg_ca3.70<==>F484/conf67.76/nbdist_avg_ca3.84/dist_cb24.30, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A79/conf84.06/nbdist_avg_ca3.96<==>B291/conf80.31/nbdist_avg_ca3.83/dist_cb24.66, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A79/conf84.06/nbdist_avg_ca3.96<==>B327/conf84.33/nbdist_avg_ca3.74/dist_cb22.44, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A91/conf79.54/nbdist_avg_ca3.83<==>F502/conf73.90/nbdist_avg_ca3.73/dist_cb24.53, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! A93/conf72.48/nbdist_avg_ca3.84<==>F425/conf74.86/nbdist_avg_ca3.90/dist_cb26.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A105/conf86.35/nbdist_avg_ca3.83<==>F611/conf69.38/nbdist_avg_ca3.91/dist_cb24.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A132/conf70.46/nbdist_avg_ca4.41<==>F477/conf57.81/nbdist_avg_ca3.99/dist_cb16.91, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A132/conf70.46/nbdist_avg_ca4.41<==>F611/conf69.38/nbdist_avg_ca3.91/dist_cb21.97, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A134/conf89.13/nbdist_avg_ca4.00<==>F477/conf57.81/nbdist_avg_ca3.99/dist_cb20.80, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A181/conf81.97/nbdist_avg_ca3.93<==>B283/conf87.63/nbdist_avg_ca3.79/dist_cb18.75, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A189/conf87.89/nbdist_avg_ca3.94<==>B267/conf79.29/nbdist_avg_ca3.95/dist_cb22.73, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! 
A190/conf87.83/nbdist_avg_ca3.85<==>B372/conf80.62/nbdist_avg_ca3.80/dist_cb22.34, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A193/conf84.63/nbdist_avg_ca3.83<==>B114/conf79.07/nbdist_avg_ca3.85/dist_cb12.41, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A196/conf77.08/nbdist_avg_ca3.89<==>B186/conf87.59/nbdist_avg_ca3.79/dist_cb18.64, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A196/conf77.08/nbdist_avg_ca3.89<==>B372/conf80.62/nbdist_avg_ca3.80/dist_cb17.86, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A200/conf75.14/nbdist_avg_ca3.79<==>B16/conf83.22/nbdist_avg_ca3.77/dist_cb19.58, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A204/conf77.23/nbdist_avg_ca3.87<==>B183/conf88.64/nbdist_avg_ca3.79/dist_cb17.84, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A205/conf76.98/nbdist_avg_ca3.83<==>B282/conf88.75/nbdist_avg_ca3.89/dist_cb14.27, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A210/conf84.59/nbdist_avg_ca3.82<==>C320/conf86.74/nbdist_avg_ca3.86/dist_cb24.06, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A214/conf87.59/nbdist_avg_ca3.84<==>B370/conf79.08/nbdist_avg_ca4.02/dist_cb24.67, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A235/conf76.60/nbdist_avg_ca3.83<==>B366/conf84.70/nbdist_avg_ca3.79/dist_cb21.75, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A246/conf72.93/nbdist_avg_ca4.05<==>C280/conf90.43/nbdist_avg_ca3.84/dist_cb15.36, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A246/conf72.93/nbdist_avg_ca4.05<==>C326/conf77.27/nbdist_avg_ca3.83/dist_cb9.06, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A252/conf84.81/nbdist_avg_ca3.88<==>B122/conf87.07/nbdist_avg_ca3.77/dist_cb23.12, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A262/conf92.31/nbdist_avg_ca3.77<==>B284/conf88.17/nbdist_avg_ca3.83/dist_cb22.41, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A352/conf75.80/nbdist_avg_ca3.99<==>F610/conf72.68/nbdist_avg_ca3.87/dist_cb21.14, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! A360/conf81.19/nbdist_avg_ca3.91<==>F612/conf54.84/nbdist_avg_ca3.95/dist_cb23.89, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B7/conf59.60/nbdist_avg_ca3.64<==>H529/conf82.94/nbdist_avg_ca3.83/dist_cb18.05, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B45/conf65.56/nbdist_avg_ca3.90<==>D338/conf83.64/nbdist_avg_ca3.89/dist_cb23.75, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B47/conf55.95/nbdist_avg_ca3.70<==>H458/conf71.92/nbdist_avg_ca3.79/dist_cb12.81, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B49/conf54.07/nbdist_avg_ca3.64<==>H452/conf62.02/nbdist_avg_ca3.77/dist_cb21.75, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B52/conf60.96/nbdist_avg_ca3.77<==>H439/conf74.43/nbdist_avg_ca3.71/dist_cb24.22, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B52/conf60.96/nbdist_avg_ca3.77<==>H467/conf78.14/nbdist_avg_ca3.96/dist_cb11.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B52/conf60.96/nbdist_avg_ca3.77<==>H473/conf74.49/nbdist_avg_ca3.76/dist_cb18.86, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B54/conf73.85/nbdist_avg_ca3.90<==>C268/conf77.85/nbdist_avg_ca4.00/dist_cb24.36, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! 
B54/conf73.85/nbdist_avg_ca3.90<==>H478/conf67.62/nbdist_avg_ca3.67/dist_cb15.02, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! B54/conf73.85/nbdist_avg_ca3.90<==>H499/conf73.71/nbdist_avg_ca3.89/dist_cb25.16, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B60/conf79.74/nbdist_avg_ca3.78<==>D293/conf79.23/nbdist_avg_ca3.86/dist_cb18.30, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B62/conf74.30/nbdist_avg_ca3.88<==>D141/conf81.63/nbdist_avg_ca3.87/dist_cb23.36, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B68/conf76.22/nbdist_avg_ca3.96<==>C284/conf85.97/nbdist_avg_ca3.82/dist_cb22.64, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B68/conf76.22/nbdist_avg_ca3.96<==>C314/conf89.20/nbdist_avg_ca3.79/dist_cb23.89, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B68/conf76.22/nbdist_avg_ca3.96<==>D290/conf74.48/nbdist_avg_ca4.03/dist_cb13.49, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B68/conf76.22/nbdist_avg_ca3.96<==>D292/conf80.43/nbdist_avg_ca3.88/dist_cb19.09, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! B92/conf81.87/nbdist_avg_ca3.80<==>H601/conf75.74/nbdist_avg_ca3.83/dist_cb29.83, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B102/conf70.96/nbdist_avg_ca4.11<==>H494/conf75.11/nbdist_avg_ca3.76/dist_cb24.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B190/conf90.17/nbdist_avg_ca3.94<==>D290/conf74.48/nbdist_avg_ca4.03/dist_cb22.84, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B193/conf83.56/nbdist_avg_ca3.93<==>C111/conf83.57/nbdist_avg_ca3.83/dist_cb11.48, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B193/conf83.56/nbdist_avg_ca3.93<==>C371/conf76.83/nbdist_avg_ca3.74/dist_cb23.20, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! B200/conf75.81/nbdist_avg_ca3.77<==>C134/conf90.83/nbdist_avg_ca3.87/dist_cb27.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B200/conf75.81/nbdist_avg_ca3.77<==>D295/conf86.05/nbdist_avg_ca3.90/dist_cb21.72, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B201/conf70.68/nbdist_avg_ca3.89<==>C83/conf85.11/nbdist_avg_ca3.81/dist_cb21.27, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B204/conf77.01/nbdist_avg_ca3.91<==>C322/conf86.73/nbdist_avg_ca3.89/dist_cb21.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B214/conf87.11/nbdist_avg_ca3.82<==>D328/conf80.88/nbdist_avg_ca3.67/dist_cb23.23, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B217/conf88.48/nbdist_avg_ca3.88<==>D325/conf77.35/nbdist_avg_ca3.95/dist_cb22.06, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B241/conf86.01/nbdist_avg_ca3.78<==>C373/conf74.26/nbdist_avg_ca3.92/dist_cb24.91, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B251/conf85.96/nbdist_avg_ca3.86<==>D323/conf86.38/nbdist_avg_ca3.92/dist_cb19.73, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B252/conf87.41/nbdist_avg_ca3.88<==>D325/conf77.35/nbdist_avg_ca3.95/dist_cb20.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! B254/conf77.75/nbdist_avg_ca4.01<==>C79/conf78.29/nbdist_avg_ca3.95/dist_cb23.47, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! B339/conf83.38/nbdist_avg_ca3.99<==>H479/conf63.39/nbdist_avg_ca3.87/dist_cb25.75, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! 
C8/conf68.06/nbdist_avg_ca3.71<==>G601/conf74.52/nbdist_avg_ca3.80/dist_cb27.62, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C11/conf83.60/nbdist_avg_ca3.87<==>F443/conf60.81/nbdist_avg_ca3.98/dist_cb20.36, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C20/conf86.00/nbdist_avg_ca3.77<==>F455/conf68.52/nbdist_avg_ca3.63/dist_cb21.55, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C24/conf74.65/nbdist_avg_ca3.83<==>G547/conf62.74/nbdist_avg_ca4.08/dist_cb16.67, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C27/conf70.68/nbdist_avg_ca3.90<==>G492/conf68.16/nbdist_avg_ca4.04/dist_cb22.00, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C43/conf67.61/nbdist_avg_ca3.79<==>E296/conf83.35/nbdist_avg_ca4.12/dist_cb24.44, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C45/conf67.49/nbdist_avg_ca3.85<==>E360/conf86.23/nbdist_avg_ca3.73/dist_cb19.33, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C49/conf51.65/nbdist_avg_ca3.62<==>G452/conf59.98/nbdist_avg_ca3.87/dist_cb21.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C52/conf62.86/nbdist_avg_ca3.67<==>E351/conf72.11/nbdist_avg_ca3.88/dist_cb20.17, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C54/conf72.43/nbdist_avg_ca3.89<==>G475/conf72.58/nbdist_avg_ca3.86/dist_cb19.61, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C61/conf71.83/nbdist_avg_ca3.99<==>G492/conf68.16/nbdist_avg_ca4.04/dist_cb15.87, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C62/conf65.29/nbdist_avg_ca3.96<==>E325/conf73.93/nbdist_avg_ca3.88/dist_cb24.33, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C62/conf65.29/nbdist_avg_ca3.96<==>G460/conf72.79/nbdist_avg_ca3.95/dist_cb17.81, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C63/conf72.46/nbdist_avg_ca4.20<==>G421/conf73.07/nbdist_avg_ca4.02/dist_cb23.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C63/conf72.46/nbdist_avg_ca4.20<==>G500/conf68.49/nbdist_avg_ca4.03/dist_cb22.81, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C68/conf77.27/nbdist_avg_ca3.91<==>E283/conf87.53/nbdist_avg_ca3.81/dist_cb22.20, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C68/conf77.27/nbdist_avg_ca3.91<==>E287/conf79.32/nbdist_avg_ca3.82/dist_cb15.92, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C71/conf85.90/nbdist_avg_ca3.83<==>D285/conf78.15/nbdist_avg_ca3.91/dist_cb16.62, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C73/conf79.99/nbdist_avg_ca3.79<==>D279/conf86.25/nbdist_avg_ca3.93/dist_cb23.92, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C86/conf79.75/nbdist_avg_ca4.04<==>D274/conf82.08/nbdist_avg_ca3.77/dist_cb24.30, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! C86/conf79.75/nbdist_avg_ca4.04<==>G501/conf73.07/nbdist_avg_ca4.08/dist_cb25.22, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C90/conf78.71/nbdist_avg_ca4.05<==>G492/conf68.16/nbdist_avg_ca4.04/dist_cb14.57, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C93/conf72.16/nbdist_avg_ca4.07<==>G438/conf73.78/nbdist_avg_ca3.87/dist_cb20.39, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C93/conf72.16/nbdist_avg_ca4.07<==>G465/conf76.66/nbdist_avg_ca3.76/dist_cb22.47, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! 
C129/conf78.20/nbdist_avg_ca4.00<==>G506/conf62.60/nbdist_avg_ca3.89/dist_cb25.58, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C196/conf78.19/nbdist_avg_ca3.90<==>D141/conf81.63/nbdist_avg_ca3.87/dist_cb19.36, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C196/conf78.19/nbdist_avg_ca3.90<==>D370/conf79.54/nbdist_avg_ca3.98/dist_cb18.92, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C196/conf78.19/nbdist_avg_ca3.90<==>E281/conf88.84/nbdist_avg_ca3.92/dist_cb24.83, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C200/conf76.54/nbdist_avg_ca3.77<==>E283/conf87.53/nbdist_avg_ca3.81/dist_cb22.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C201/conf75.64/nbdist_avg_ca3.78<==>D192/conf88.39/nbdist_avg_ca3.77/dist_cb24.00, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C202/conf79.56/nbdist_avg_ca3.90<==>D305/conf87.98/nbdist_avg_ca3.82/dist_cb23.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C204/conf77.99/nbdist_avg_ca3.92<==>D305/conf87.98/nbdist_avg_ca3.82/dist_cb22.05, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C212/conf89.16/nbdist_avg_ca3.89<==>D273/conf79.04/nbdist_avg_ca3.70/dist_cb19.05, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C236/conf78.92/nbdist_avg_ca3.81<==>E326/conf75.55/nbdist_avg_ca3.86/dist_cb22.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C237/conf75.72/nbdist_avg_ca3.72<==>D115/conf78.05/nbdist_avg_ca3.89/dist_cb20.08, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C241/conf85.80/nbdist_avg_ca3.75<==>D77/conf79.81/nbdist_avg_ca3.84/dist_cb22.08, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C249/conf84.81/nbdist_avg_ca3.86<==>E317/conf88.20/nbdist_avg_ca4.15/dist_cb22.00, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C262/conf89.95/nbdist_avg_ca3.83<==>D180/conf83.39/nbdist_avg_ca3.89/dist_cb23.95, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C268/conf77.85/nbdist_avg_ca4.00<==>D281/conf88.62/nbdist_avg_ca4.00/dist_cb24.11, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C305/conf86.88/nbdist_avg_ca3.85<==>D286/conf79.82/nbdist_avg_ca3.88/dist_cb20.88, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C346/conf79.76/nbdist_avg_ca3.80<==>F442/conf71.61/nbdist_avg_ca4.05/dist_cb21.80, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! C346/conf79.76/nbdist_avg_ca3.80<==>F452/conf63.91/nbdist_avg_ca3.73/dist_cb13.95, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D28/conf82.30/nbdist_avg_ca3.84<==>H458/conf71.92/nbdist_avg_ca3.79/dist_cb17.61, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D54/conf83.92/nbdist_avg_ca3.85<==>E265/conf90.55/nbdist_avg_ca3.86/dist_cb22.02, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D70/conf85.32/nbdist_avg_ca3.85<==>E225/conf87.97/nbdist_avg_ca3.89/dist_cb23.98, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D71/conf85.67/nbdist_avg_ca3.84<==>E267/conf83.93/nbdist_avg_ca3.91/dist_cb20.03, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D142/conf80.91/nbdist_avg_ca3.82<==>H452/conf62.02/nbdist_avg_ca3.77/dist_cb16.50, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D145/conf77.18/nbdist_avg_ca3.72<==>H425/conf78.86/nbdist_avg_ca3.93/dist_cb18.34, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! 
D182/conf85.30/nbdist_avg_ca3.83<==>E273/conf78.89/nbdist_avg_ca3.71/dist_cb25.16, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D193/conf85.26/nbdist_avg_ca3.82<==>E109/conf83.14/nbdist_avg_ca3.78/dist_cb17.14, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D200/conf76.12/nbdist_avg_ca3.78<==>E74/conf81.98/nbdist_avg_ca3.84/dist_cb14.77, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D200/conf76.12/nbdist_avg_ca3.78<==>E374/conf73.33/nbdist_avg_ca3.98/dist_cb23.94, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D202/conf78.94/nbdist_avg_ca3.85<==>E280/conf89.41/nbdist_avg_ca3.83/dist_cb21.39, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D204/conf77.38/nbdist_avg_ca3.86<==>E272/conf75.66/nbdist_avg_ca3.82/dist_cb4.84, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D210/conf85.32/nbdist_avg_ca3.76<==>E273/conf78.89/nbdist_avg_ca3.71/dist_cb16.75, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D233/conf77.98/nbdist_avg_ca3.87<==>E364/conf86.15/nbdist_avg_ca3.86/dist_cb17.53, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D244/conf85.81/nbdist_avg_ca3.93<==>E79/conf81.58/nbdist_avg_ca3.82/dist_cb19.36, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D250/conf85.28/nbdist_avg_ca3.87<==>E16/conf82.64/nbdist_avg_ca3.83/dist_cb23.64, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D251/conf84.83/nbdist_avg_ca3.84<==>E371/conf77.88/nbdist_avg_ca3.82/dist_cb24.19, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D262/conf91.09/nbdist_avg_ca3.86<==>E288/conf76.24/nbdist_avg_ca3.79/dist_cb18.22, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D272/conf74.58/nbdist_avg_ca3.84<==>E281/conf88.84/nbdist_avg_ca3.92/dist_cb22.69, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D330/conf82.67/nbdist_avg_ca3.83<==>H487/conf77.20/nbdist_avg_ca3.91/dist_cb20.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D346/conf82.21/nbdist_avg_ca3.75<==>H444/conf75.45/nbdist_avg_ca3.84/dist_cb18.70, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D352/conf69.56/nbdist_avg_ca3.92<==>H490/conf74.05/nbdist_avg_ca3.82/dist_cb24.47, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! D353/conf65.79/nbdist_avg_ca4.10<==>H484/conf72.63/nbdist_avg_ca3.78/dist_cb14.62, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E7/conf60.60/nbdist_avg_ca3.71<==>G450/conf61.19/nbdist_avg_ca3.64/dist_cb21.25, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E108/conf85.17/nbdist_avg_ca3.77<==>G456/conf69.45/nbdist_avg_ca3.89/dist_cb21.52, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E132/conf87.80/nbdist_avg_ca3.80<==>G443/conf73.68/nbdist_avg_ca3.83/dist_cb19.95, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E144/conf80.71/nbdist_avg_ca3.75<==>G459/conf70.80/nbdist_avg_ca3.91/dist_cb15.88, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E147/conf75.02/nbdist_avg_ca3.90<==>G460/conf72.79/nbdist_avg_ca3.95/dist_cb13.56, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E294/conf73.74/nbdist_avg_ca4.31<==>G450/conf61.19/nbdist_avg_ca3.64/dist_cb23.23, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E297/conf73.84/nbdist_avg_ca4.12<==>G420/conf71.51/nbdist_avg_ca3.99/dist_cb20.56, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Violated! 
E316/conf86.73/nbdist_avg_ca4.13<==>G450/conf61.19/nbdist_avg_ca3.64/dist_cb27.59, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E330/conf77.72/nbdist_avg_ca3.93<==>G457/conf71.79/nbdist_avg_ca3.91/dist_cb19.08, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E353/conf66.52/nbdist_avg_ca3.99<==>G472/conf75.47/nbdist_avg_ca3.79/dist_cb14.50, range: 0-25.0, rm_score 0, rm_thre 0.0 +Included! Satisfied! E353/conf66.52/nbdist_avg_ca3.99<==>G494/conf73.55/nbdist_avg_ca3.95/dist_cb21.98, range: 0-25.0, rm_score 0, rm_thre 0.0 +>>>>> Total 152: 152 included, 142 satisfied +Breakage info ========== +Break number: 0, Max neighbour CA dist: 4.875 + +Recall info============= +interchain (w 1): recall 0.7513227512829987, recall weighted by confidence: 0.7416862316199593 +Stop iteration: Converged +Inference done! +time cost: 6604.073527097702 + +``` + diff --git a/MindSPONGE/applications/research/Grasp/cell/.zip b/MindSPONGE/applications/research/Grasp/cell/.zip new file mode 100644 index 0000000000000000000000000000000000000000..46358076015a655ddfdeadfe704de9a14a9ea879 Binary files /dev/null and b/MindSPONGE/applications/research/Grasp/cell/.zip differ diff --git a/MindSPONGE/applications/research/Grasp/cell/equivariant.py b/MindSPONGE/applications/research/Grasp/cell/equivariant.py new file mode 100644 index 0000000000000000000000000000000000000000..2589efe242230b8cd2ab0bf6a3d673d862431179 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/cell/equivariant.py @@ -0,0 +1,212 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Equivariant""" +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +import mindspore.numpy as mnp +import mindspore.ops as ops +from mindspore import Parameter +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindsponge1.common.geometry import apply_to_point, invert_point +from mindsponge1.cell.initializer import lecun_init +from mindsponge1.cell.sbr import ProcessSBR +from mindsponge1.cell.interface import AddInterface +# from mindsponge1.cell import AddInterface, ProcessSBR +from common.geometry import multimer_vecs_robust_norm, multimer_square_euclidean_distance + + +class MultimerInvariantPointAttention(nn.Cell): + """Invariant Point attention module.""" + + def __init__(self, num_head, num_scalar_qk, num_scalar_v, num_point_v, num_point_qk, num_channel, pair_dim, + device_num): + """ + + Args: + pair_dim: pair representation dimension. 
+ """ + + super(MultimerInvariantPointAttention, self).__init__() + + self._dist_epsilon = Tensor(1e-8, mstype.float32) + self.num_head = num_head + self.num_scalar_qk = num_scalar_qk + self.num_scalar_v = num_scalar_v + self.num_point_v = num_point_v + self.num_point_qk = num_point_qk + self.num_channel = num_channel + self.projection_num = self.num_head * self.num_scalar_v + self.num_head * self.num_point_v * 4 + \ + self.num_head * pair_dim + self.q_scalar = nn.Dense(self.num_channel, self.num_head * self.num_scalar_qk, + weight_init=lecun_init(self.num_channel), has_bias=False) + self.k_scalar = nn.Dense(self.num_channel, self.num_head * self.num_scalar_qk, + weight_init=lecun_init(self.num_channel), has_bias=False) + self.v_scalar = nn.Dense(self.num_channel, self.num_head * self.num_scalar_v, + weight_init=lecun_init(self.num_channel), has_bias=False) + self.q_point_local = nn.Dense(self.num_channel, self.num_head * 3 * self.num_point_qk, + weight_init=lecun_init(self.num_channel)) + self.k_point_local = nn.Dense(self.num_channel, self.num_head * 3 * self.num_point_qk, + weight_init=lecun_init(self.num_channel)) + self.v_point_local = nn.Dense(self.num_channel, self.num_head * 3 * self.num_point_v, + weight_init=lecun_init(self.num_channel)) + self.soft_max = nn.Softmax(axis=-2) + self.trainable_point_weights = Parameter(Tensor(np.ones((12,)), mstype.float32), name="trainable_point_weights") + self.attention_2d = nn.Dense(pair_dim, self.num_head, weight_init=lecun_init(pair_dim)) + self.output_projection = nn.Dense(self.projection_num, self.num_channel, weight_init='zeros') + self.point_weights = Tensor(np.sqrt(1.0 / (max(num_point_qk, 1) * 9. / 2))) + self.scalar_weights = Tensor(np.sqrt(1.0 / (max(num_scalar_qk, 1) * 1.))) + self.bacth_matmul = P.BatchMatMul().shard(((1, device_num, 1), (1, 1, 1))) + self.bacth_matmul2 = P.BatchMatMul().shard(((device_num, 1, 1), (device_num, 1, 1))) + self.concat_e_6_2 = P.Concat(-1).shard(((device_num, 1), (device_num, 1), (device_num, 1), + (device_num, 1), (device_num, 1), (device_num, 1))) + # interface + self.add_interface = AddInterface(num_channel) + # sbr + self.process_sbr = ProcessSBR(128, self.num_head) + # self.sbr_layer = nn.Dense(128, self.num_head, weight_init='zeros', has_bias=False).to_float(mstype.float16) + # self.trans = P.Transpose().shard(((1, device_num, 1),)) + + def construct(self, inputs_1d, inputs_2d, mask, rotation, translation, sbr_act, sbr_mask, interface_mask): + """Compute geometry-aware attention. + + Args: + inputs_1d: (N, C) 1D input embedding that is the basis for the + scalar queries. + inputs_2d: (N, M, C') 2D input embedding, used for biases and values. + mask: (N, 1) mask to indicate which elements of inputs_1d participate + in the attention. + rotation: describe the orientation of every element in inputs_1d + translation: describe the position of every element in inputs_1d + + Returns: + Transformation of the input embedding. + """ + num_residues, _ = inputs_1d.shape + inputs_1d += self.add_interface(interface_mask, inputs_1d) + num_head = self.num_head + attn_logits = 0. 
+        num_point_qk = self.num_point_qk
+        point_weights = self.point_weights
+        # Softplus keeps the per-head point weights positive.
+        trainable_point_weights = mnp.logaddexp(self.trainable_point_weights,
+                                                mnp.zeros_like(self.trainable_point_weights))
+        point_weights = point_weights * trainable_point_weights
+
+        q_point_local = self.q_point_local(inputs_1d)
+        q_point_local = mnp.reshape(q_point_local, (num_residues, num_head, num_point_qk * 3))
+        q_point_local = mnp.split(q_point_local, 3, axis=-1)
+        q_point_local = (ops.Squeeze()(q_point_local[0]), ops.Squeeze()(q_point_local[1]),
+                         ops.Squeeze()(q_point_local[2]))
+        # Project query points into the global frame.
+        q_point_global = apply_to_point(rotation, translation, q_point_local, 2)
+        q_point = [q_point_global[0][:, None, :, :], q_point_global[1][:, None, :, :], q_point_global[2][:, None, :, :]]
+
+        k_point_local = self.k_point_local(inputs_1d)
+        k_point_local = mnp.reshape(k_point_local, (num_residues, num_head, num_point_qk * 3))
+        k_point_local = mnp.split(k_point_local, 3, axis=-1)
+        k_point_local = (ops.Squeeze()(k_point_local[0]), ops.Squeeze()(k_point_local[1]),
+                         ops.Squeeze()(k_point_local[2]))
+        # Project key points into the global frame.
+        k_point_global = apply_to_point(rotation, translation, k_point_local, 2)
+        k_point = [k_point_global[0][None, :, :, :], k_point_global[1][None, :, :, :], k_point_global[2][None, :, :, :]]
+
+        dist2 = multimer_square_euclidean_distance(q_point, k_point, epsilon=0.)
+
+        attn_qk_point = -0.5 * mnp.sum(point_weights[:, None] * dist2, axis=-1)
+        attn_logits += attn_qk_point
+
+        num_scalar_qk = self.num_scalar_qk
+
+        scalar_weights = self.scalar_weights
+        q_scalar = self.q_scalar(inputs_1d)
+        q_scalar = mnp.reshape(q_scalar, [num_residues, num_head, num_scalar_qk])
+
+        k_scalar = self.k_scalar(inputs_1d)
+        k_scalar = mnp.reshape(k_scalar, [num_residues, num_head, num_scalar_qk])
+
+        q_scalar *= scalar_weights
+        q = mnp.swapaxes(q_scalar, -2, -3)
+        k = mnp.swapaxes(k_scalar, -2, -3)
+        attn_qk_scalar = self.batch_matmul(q, mnp.swapaxes(k, -2, -1))
+        attn_qk_scalar = mnp.swapaxes(attn_qk_scalar, -2, -3)
+        attn_qk_scalar = mnp.swapaxes(attn_qk_scalar, -2, -1)
+        attn_logits += attn_qk_scalar
+        attention_2d = self.attention_2d(inputs_2d)
+        attn_logits += attention_2d
+
+        sbr_act = self.process_sbr(sbr_act, sbr_mask, useperm=True)
+        attn_logits += sbr_act
+
+        mask_2d = mask * mnp.swapaxes(mask, -1, -2)
+        # Mask out invalid pairs with a large negative bias (infer: 1e5, 50).
+        attn_logits -= 1e5 * (1. - mask_2d[..., None])
+        attn_logits *= mnp.sqrt(1. / 3)
+        attn = self.soft_max(attn_logits)
+        num_scalar_v = self.num_scalar_v
+        v_scalar = self.v_scalar(inputs_1d)
+        v_scalar = mnp.reshape(v_scalar, [num_residues, num_head, num_scalar_v])
+
+        attn_tmp = mnp.swapaxes(attn, -1, -2)
+        attn_tmp = mnp.swapaxes(attn_tmp, -2, -3)
+        result_scalar = self.batch_matmul(attn_tmp, mnp.swapaxes(v_scalar, -2, -3))
+        result_scalar = mnp.swapaxes(result_scalar, -2, -3)
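+        # The value points are aggregated in the global frame with the same
+        # attention weights, then mapped back into each residue's local frame
+        # before the output projection.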
+        num_point_v = self.num_point_v
+
+        v_point_local = self.v_point_local(inputs_1d)
+        v_point_local = mnp.reshape(v_point_local, (num_residues, num_head, num_point_v * 3))
+        v_point_local = mnp.split(v_point_local, 3, axis=-1)
+        v_point_local = (ops.Squeeze()(v_point_local[0]), ops.Squeeze()(v_point_local[1]),
+                         ops.Squeeze()(v_point_local[2]))
+        # Project value points into the global frame.
+        v_point_global = apply_to_point(rotation, translation, v_point_local, 2)
+        v_point = [v_point_global[0][None], v_point_global[1][None], v_point_global[2][None]]
+
+        result_point_global = [mnp.sum(attn[..., None] * v_point[0], axis=-3),
+                               mnp.sum(attn[..., None] * v_point[1], axis=-3),
+                               mnp.sum(attn[..., None] * v_point[2], axis=-3)
+                               ]
+
+        num_query_residues, _ = inputs_1d.shape
+
+        result_scalar = mnp.reshape(result_scalar, [num_query_residues, -1])
+
+        output_feature1 = result_scalar
+        result_point_global = [mnp.reshape(result_point_global[0], [num_query_residues, -1]),
+                               mnp.reshape(result_point_global[1], [num_query_residues, -1]),
+                               mnp.reshape(result_point_global[2], [num_query_residues, -1])]
+        result_point_local = invert_point(result_point_global, rotation, translation, 1)
+        output_feature20 = result_point_local[0]
+        output_feature21 = result_point_local[1]
+        output_feature22 = result_point_local[2]
+        point_norms = multimer_vecs_robust_norm(result_point_local, self._dist_epsilon)
+        output_feature3 = point_norms
+
+        result_attention_over_2d = self.batch_matmul2(mnp.swapaxes(attn, 1, 2), inputs_2d)
+        output_feature4 = mnp.reshape(result_attention_over_2d, [num_query_residues, -1])
+        final_act = self.concat_e_6_2([output_feature1, output_feature20, output_feature21,
+                                       output_feature22, output_feature3, output_feature4])
+
+        final_result = self.output_projection(final_act)
+        return final_result
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Grasp/common/geometry.py b/MindSPONGE/applications/research/Grasp/common/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..7946de9971a6e905c0c21f87307ecb1791998331
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/common/geometry.py
@@ -0,0 +1,155 @@
+# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Geometry"""
+import mindspore.numpy as mnp
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+
+from mindsponge1.common.geometry import vecs_dot_vecs, vecs_sub, vecs_cross_vecs, \
+    rots_expand_dims, vecs_expand_dims, invert_rigids, rigids_mul_vecs, \
+    vecs_from_tensor, vecs_scale
+
+
+def rots_mul_rots(r1, r2):
+    """Element-wise (Hadamard) product of two rotation tuples."""
+    out = (r1[0] * r2[0], r1[1] * r2[1], r1[2] * r2[2],
+           r1[3] * r2[3], r1[4] * r2[4], r1[5] * r2[5],
+           r1[6] * r2[6], r1[7] * r2[7], r1[8] * r2[8])
+    return out
+
+
+def trans_mul_trans(t1, t2):
+    """Element-wise product of two translation tuples."""
+    out = (t1[0] * t2[0], t1[1] * t2[1], t1[2] * t2[2])
+    return out
+
+
+def multimer_vecs_robust_norm(v, epsilon=1e-6):
+    """Computes the norm of vectors 'v' (multimer version)."""
+    v_l2_norm = v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
+    if epsilon:
+        # Clamp the squared norm at 1e-6 (== 1e-3 ** 2) for numerical stability.
+        v_l2_norm = F.maximum(v_l2_norm, 1e-3 ** 2)
+    return mnp.sqrt(v_l2_norm)
+
+
+def multimer_vecs_robust_normalize(v, epsilon=1e-6):
+    """Normalizes vectors 'v' (multimer version)."""
+    norms = multimer_vecs_robust_norm(v, epsilon)
+    return (v[0] / norms, v[1] / norms, v[2] / norms)
+
+
+def multimer_rots_from_two_vecs(e0_unnormalized, e1_unnormalized):
+    """Builds a rotation from two unnormalized vectors via Gram-Schmidt."""
+    e0 = multimer_vecs_robust_normalize(e0_unnormalized)
+    c = vecs_dot_vecs(e1_unnormalized, e0)
+    e1 = vecs_sub(e1_unnormalized, vecs_scale(e0, c))
+
+    e1 = multimer_vecs_robust_normalize(e1)
+    e2 = vecs_cross_vecs(e0, e1)
+
+    rots = (e0[0], e1[0], e2[0],
+            e0[1], e1[1], e2[1],
+            e0[2], e1[2], e2[2])
+    return rots
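+
+
+# multimer_rots_from_two_vecs above is a Gram-Schmidt construction: e0 is the
+# normalized first vector, e1 is the second vector with its e0 component
+# removed and renormalized, and e2 = e0 x e1 completes a right-handed frame
+# whose columns are (e0, e1, e2). A minimal NumPy sketch of the same idea,
+# for illustration only (not part of the model code):
+#
+#     import numpy as np
+#
+#     def rots_from_two_vecs_np(e0_un, e1_un):
+#         e0 = e0_un / np.linalg.norm(e0_un)
+#         e1 = e1_un - np.dot(e1_un, e0) * e0
+#         e1 = e1 / np.linalg.norm(e1)
+#         e2 = np.cross(e0, e1)
+#         return np.stack([e0, e1, e2], axis=-1)  # columns are e0, e1, e2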
""" + m = multimer_rots_from_two_vecs( + e0_unnormalized=vecs_sub(vec_c, vec_b), + e1_unnormalized=vecs_sub(vec_a, vec_b)) + rigid = (m, vec_b) + return rigid + + +def multimer_rigids_get_unit_vector(point_a, point_b, point_c): + """multimer_rigids_get_unit_vector.""" + # print("debug point_a b c", + # "point_a", point_a, + # "point_b", point_b, + # "point_c", point_c) + rigid = multimer_rigids_from_3_points(vecs_from_tensor(point_a), + vecs_from_tensor(point_b), + vecs_from_tensor(point_c)) + rot, trans = rigid + rotation = rots_expand_dims(rot, -1) + translation = vecs_expand_dims(trans, -1) + inv_rigid = invert_rigids((rotation, translation)) + rigid_vec = rigids_mul_vecs(inv_rigid, vecs_expand_dims(trans, -2)) + unit_vector = multimer_vecs_robust_normalize(rigid_vec) + return unit_vector + + +def multimer_rigids_compute_dihedral_angle(a, b, c, d): + """multimer_rigids_compute_dihedral_angle.""" + v1 = vecs_sub(a, b) + v2 = vecs_sub(b, c) + v3 = vecs_sub(d, c) + + c1 = vecs_cross_vecs(v1, v2) + c2 = vecs_cross_vecs(v3, v2) + c3 = vecs_cross_vecs(c2, c1) + + v2_mag = multimer_vecs_robust_norm(v2) + return mnp.arctan2(vecs_dot_vecs(c3, v2), v2_mag * vecs_dot_vecs(c1, c2)) + + +def multimer_from_quaternion(w, x, y, z, normalize=True, epsilon=1e-6): + """multimer_from_quaternion.""" + if normalize: + inv_norm = P.Rsqrt()(mnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2)) + w *= inv_norm + x *= inv_norm + y *= inv_norm + z *= inv_norm + xx = 1 - 2 * (mnp.square(y) + mnp.square(z)) + xy = 2 * (x * y - w * z) + xz = 2 * (x * z + w * y) + yx = 2 * (x * y + w * z) + yy = 1 - 2 * (mnp.square(x) + mnp.square(z)) + yz = 2 * (y * z - w * x) + zx = 2 * (x * z - w * y) + zy = 2 * (y * z + w * x) + zz = 1 - 2 * (mnp.square(x) + mnp.square(y)) + rots = (xx, xy, xz, + yx, yy, yz, + zx, zy, zz) + return rots + + +def multimer_square_euclidean_distance(v1, v2, epsilon): + """multimer_square_euclidean_distance.""" + difference = vecs_sub(v1, v2) + distance = vecs_dot_vecs(difference, difference) + if epsilon: + distance = F.maximum(distance, epsilon) + return distance \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/common/new_evo.txt b/MindSPONGE/applications/research/Grasp/common/new_evo.txt new file mode 100644 index 0000000000000000000000000000000000000000..d9acfedd3a860f5980153dc59c8fd146f050b585 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/common/new_evo.txt @@ -0,0 +1,110 @@ +msa_stack.msa_row_attention_with_pair_bias.query_norm_gammas +msa_stack.msa_row_attention_with_pair_bias.query_norm_betas +msa_stack.msa_row_attention_with_pair_bias.feat_2d_norm_gammas +msa_stack.msa_row_attention_with_pair_bias.feat_2d_norm_betas +msa_stack.msa_row_attention_with_pair_bias.feat_2d_weights +msa_stack.msa_row_attention_with_pair_bias.attn_mod.linear_q_weights +msa_stack.msa_row_attention_with_pair_bias.attn_mod.linear_k_weights +msa_stack.msa_row_attention_with_pair_bias.attn_mod.linear_v_weights +msa_stack.msa_row_attention_with_pair_bias.attn_mod.linear_output_weights +msa_stack.msa_row_attention_with_pair_bias.attn_mod.o_biases +msa_stack.msa_row_attention_with_pair_bias.attn_mod.linear_gating_weights +msa_stack.msa_row_attention_with_pair_bias.attn_mod.gating_biases +msa_stack.msa_transition.input_layer_norm_gammas +msa_stack.msa_transition.input_layer_norm_betas +msa_stack.msa_transition.transition1_weights +msa_stack.msa_transition.transition1_biases +msa_stack.msa_transition.transition2_weights +msa_stack.msa_transition.transition2_biases 
+msa_stack.outer_product_mean.layer_norm_input_gammas +msa_stack.outer_product_mean.layer_norm_input_betas +msa_stack.outer_product_mean.left_projection_weights +msa_stack.outer_product_mean.left_projection_biases +msa_stack.outer_product_mean.right_projection_weights +msa_stack.outer_product_mean.right_projection_biases +msa_stack.outer_product_mean.linear_output_weights +msa_stack.outer_product_mean.o_biases +msa_stack.triangle_attention_starting_node.query_norm_gammas +msa_stack.triangle_attention_starting_node.query_norm_betas +msa_stack.triangle_attention_starting_node.feat_2d_weights +msa_stack.triangle_attention_starting_node.attn_mod.linear_q_weights +msa_stack.triangle_attention_starting_node.attn_mod.linear_k_weights +msa_stack.triangle_attention_starting_node.attn_mod.linear_v_weights +msa_stack.triangle_attention_starting_node.attn_mod.linear_output_weights +msa_stack.triangle_attention_starting_node.attn_mod.o_biases +msa_stack.triangle_attention_starting_node.attn_mod.linear_gating_weights +msa_stack.triangle_attention_starting_node.attn_mod.gating_biases +msa_stack.triangle_attention_ending_node.query_norm_gammas +msa_stack.triangle_attention_ending_node.query_norm_betas +msa_stack.triangle_attention_ending_node.feat_2d_weights +msa_stack.triangle_attention_ending_node.attn_mod.linear_q_weights +msa_stack.triangle_attention_ending_node.attn_mod.linear_k_weights +msa_stack.triangle_attention_ending_node.attn_mod.linear_v_weights +msa_stack.triangle_attention_ending_node.attn_mod.linear_output_weights +msa_stack.triangle_attention_ending_node.attn_mod.o_biases +msa_stack.triangle_attention_ending_node.attn_mod.linear_gating_weights +msa_stack.triangle_attention_ending_node.attn_mod.gating_biases +msa_stack.pair_transition.input_layer_norm_gammas +msa_stack.pair_transition.input_layer_norm_betas +msa_stack.pair_transition.transition1_weights +msa_stack.pair_transition.transition1_biases +msa_stack.pair_transition.transition2_weights +msa_stack.pair_transition.transition2_biases +msa_stack.triangle_multiplication_outgoing.layer_norm_input_gammas +msa_stack.triangle_multiplication_outgoing.layer_norm_input_betas +msa_stack.triangle_multiplication_outgoing.left_projection_weights +msa_stack.triangle_multiplication_outgoing.left_projection_biases +msa_stack.triangle_multiplication_outgoing.right_projection_weights +msa_stack.triangle_multiplication_outgoing.right_projection_biases +msa_stack.triangle_multiplication_outgoing.left_gate_weights +msa_stack.triangle_multiplication_outgoing.left_gate_biases +msa_stack.triangle_multiplication_outgoing.right_gate_weights +msa_stack.triangle_multiplication_outgoing.right_gate_biases +msa_stack.triangle_multiplication_outgoing.center_layer_norm_gammas +msa_stack.triangle_multiplication_outgoing.center_layer_norm_betas +msa_stack.triangle_multiplication_outgoing.output_projection_weights +msa_stack.triangle_multiplication_outgoing.output_projection_biases +msa_stack.triangle_multiplication_outgoing.gating_linear_weights +msa_stack.triangle_multiplication_outgoing.gating_linear_biases +msa_stack.triangle_multiplication_incoming.layer_norm_input_gammas +msa_stack.triangle_multiplication_incoming.layer_norm_input_betas +msa_stack.triangle_multiplication_incoming.left_projection_weights +msa_stack.triangle_multiplication_incoming.left_projection_biases +msa_stack.triangle_multiplication_incoming.right_projection_weights +msa_stack.triangle_multiplication_incoming.right_projection_biases 
+msa_stack.triangle_multiplication_incoming.left_gate_weights +msa_stack.triangle_multiplication_incoming.left_gate_biases +msa_stack.triangle_multiplication_incoming.right_gate_weights +msa_stack.triangle_multiplication_incoming.right_gate_biases +msa_stack.triangle_multiplication_incoming.center_layer_norm_gammas +msa_stack.triangle_multiplication_incoming.center_layer_norm_betas +msa_stack.triangle_multiplication_incoming.output_projection_weights +msa_stack.triangle_multiplication_incoming.output_projection_biases +msa_stack.triangle_multiplication_incoming.gating_linear_weights +msa_stack.triangle_multiplication_incoming.gating_linear_biases +msa_stack.attn_mod.query_norm_gammas +msa_stack.attn_mod.query_norm_betas +msa_stack.attn_mod.attn_mod.linear_q_weights +msa_stack.attn_mod.attn_mod.linear_k_weights +msa_stack.attn_mod.attn_mod.linear_v_weights +msa_stack.attn_mod.attn_mod.linear_output_weights +msa_stack.attn_mod.attn_mod.o_biases +msa_stack.attn_mod.attn_mod.linear_gating_weights +msa_stack.attn_mod.attn_mod.gating_biases +msa_stack.msa_row_attention_with_pair_bias.contact_norm_gammas +msa_stack.msa_row_attention_with_pair_bias.contact_norm_betas +msa_stack.msa_row_attention_with_pair_bias.contact_weights +msa_stack.msa_row_attention_with_pair_bias.sbr_norm_gammas +msa_stack.msa_row_attention_with_pair_bias.sbr_norm_betas +msa_stack.msa_row_attention_with_pair_bias.sbr_weights +msa_stack.msa_row_attention_with_pair_bias.add_interface.input_layer_norm_gammas +msa_stack.msa_row_attention_with_pair_bias.add_interface.input_layer_norm_betas +msa_stack.msa_row_attention_with_pair_bias.add_interface.linear_weights +msa_stack.msa_row_attention_with_pair_bias.add_interface.linear_biases +msa_stack.msa_row_attention_with_pair_bias.preprocess_sbr.input_layer_norm_gammas +msa_stack.msa_row_attention_with_pair_bias.preprocess_sbr.input_layer_norm_betas +msa_stack.msa_row_attention_with_pair_bias.preprocess_sbr.linear_weights +msa_stack.preprocess_sbr.input_layer_norm_gammas +msa_stack.preprocess_sbr.input_layer_norm_betas +msa_stack.preprocess_sbr.linear_weights +msa_stack.preprocess_sbr.linear_biases diff --git a/MindSPONGE/applications/research/Grasp/common/new_extra.txt b/MindSPONGE/applications/research/Grasp/common/new_extra.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a3cf374e7b5d53028836d3910ac5294e4f2c548 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/common/new_extra.txt @@ -0,0 +1,93 @@ +extra_msa_stack.0.msa_row_attention_with_pair_bias.query_norm_gammas +extra_msa_stack.0.msa_row_attention_with_pair_bias.query_norm_betas +extra_msa_stack.0.msa_row_attention_with_pair_bias.feat_2d_norm_gammas +extra_msa_stack.0.msa_row_attention_with_pair_bias.feat_2d_norm_betas +extra_msa_stack.0.msa_row_attention_with_pair_bias.feat_2d_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_q_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_k_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_v_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_output_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.o_biases +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_gating_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.gating_biases +extra_msa_stack.0.msa_transition.input_layer_norm_gammas +extra_msa_stack.0.msa_transition.input_layer_norm_betas +extra_msa_stack.0.msa_transition.transition1_weights 
+extra_msa_stack.0.msa_transition.transition1_biases +extra_msa_stack.0.msa_transition.transition2_weights +extra_msa_stack.0.msa_transition.transition2_biases +extra_msa_stack.0.outer_product_mean.layer_norm_input_gammas +extra_msa_stack.0.outer_product_mean.layer_norm_input_betas +extra_msa_stack.0.outer_product_mean.left_projection_weights +extra_msa_stack.0.outer_product_mean.left_projection_biases +extra_msa_stack.0.outer_product_mean.right_projection_weights +extra_msa_stack.0.outer_product_mean.right_projection_biases +extra_msa_stack.0.outer_product_mean.linear_output_weights +extra_msa_stack.0.outer_product_mean.o_biases +extra_msa_stack.0.triangle_attention_starting_node.query_norm_gammas +extra_msa_stack.0.triangle_attention_starting_node.query_norm_betas +extra_msa_stack.0.triangle_attention_starting_node.feat_2d_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_q_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_k_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_v_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_output_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.o_biases +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_gating_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.gating_biases +extra_msa_stack.0.triangle_attention_ending_node.query_norm_gammas +extra_msa_stack.0.triangle_attention_ending_node.query_norm_betas +extra_msa_stack.0.triangle_attention_ending_node.feat_2d_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_q_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_k_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_v_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_output_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.o_biases +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_gating_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.gating_biases +extra_msa_stack.0.pair_transition.input_layer_norm_gammas +extra_msa_stack.0.pair_transition.input_layer_norm_betas +extra_msa_stack.0.pair_transition.transition1_weights +extra_msa_stack.0.pair_transition.transition1_biases +extra_msa_stack.0.pair_transition.transition2_weights +extra_msa_stack.0.pair_transition.transition2_biases +extra_msa_stack.0.triangle_multiplication_outgoing.layer_norm_input_gammas +extra_msa_stack.0.triangle_multiplication_outgoing.layer_norm_input_betas +extra_msa_stack.0.triangle_multiplication_outgoing.left_projection_weights +extra_msa_stack.0.triangle_multiplication_outgoing.left_projection_biases +extra_msa_stack.0.triangle_multiplication_outgoing.right_projection_weights +extra_msa_stack.0.triangle_multiplication_outgoing.right_projection_biases +extra_msa_stack.0.triangle_multiplication_outgoing.left_gate_weights +extra_msa_stack.0.triangle_multiplication_outgoing.left_gate_biases +extra_msa_stack.0.triangle_multiplication_outgoing.right_gate_weights +extra_msa_stack.0.triangle_multiplication_outgoing.right_gate_biases +extra_msa_stack.0.triangle_multiplication_outgoing.center_layer_norm_gammas +extra_msa_stack.0.triangle_multiplication_outgoing.center_layer_norm_betas +extra_msa_stack.0.triangle_multiplication_outgoing.output_projection_weights +extra_msa_stack.0.triangle_multiplication_outgoing.output_projection_biases 
+extra_msa_stack.0.triangle_multiplication_outgoing.gating_linear_weights +extra_msa_stack.0.triangle_multiplication_outgoing.gating_linear_biases +extra_msa_stack.0.triangle_multiplication_incoming.layer_norm_input_gammas +extra_msa_stack.0.triangle_multiplication_incoming.layer_norm_input_betas +extra_msa_stack.0.triangle_multiplication_incoming.left_projection_weights +extra_msa_stack.0.triangle_multiplication_incoming.left_projection_biases +extra_msa_stack.0.triangle_multiplication_incoming.right_projection_weights +extra_msa_stack.0.triangle_multiplication_incoming.right_projection_biases +extra_msa_stack.0.triangle_multiplication_incoming.left_gate_weights +extra_msa_stack.0.triangle_multiplication_incoming.left_gate_biases +extra_msa_stack.0.triangle_multiplication_incoming.right_gate_weights +extra_msa_stack.0.triangle_multiplication_incoming.right_gate_biases +extra_msa_stack.0.triangle_multiplication_incoming.center_layer_norm_gammas +extra_msa_stack.0.triangle_multiplication_incoming.center_layer_norm_betas +extra_msa_stack.0.triangle_multiplication_incoming.output_projection_weights +extra_msa_stack.0.triangle_multiplication_incoming.output_projection_biases +extra_msa_stack.0.triangle_multiplication_incoming.gating_linear_weights +extra_msa_stack.0.triangle_multiplication_incoming.gating_linear_biases +extra_msa_stack.0.attn_mod.query_norm_gammas +extra_msa_stack.0.attn_mod.query_norm_betas +extra_msa_stack.0.attn_mod.attn_mod.linear_q_weights +extra_msa_stack.0.attn_mod.attn_mod.linear_k_weights +extra_msa_stack.0.attn_mod.attn_mod.linear_v_weights +extra_msa_stack.0.attn_mod.attn_mod.linear_output_weights +extra_msa_stack.0.attn_mod.attn_mod.o_biases +extra_msa_stack.0.attn_mod.attn_mod.linear_gating_weights +extra_msa_stack.0.attn_mod.attn_mod.gating_biases \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/common/old_evo.txt b/MindSPONGE/applications/research/Grasp/common/old_evo.txt new file mode 100644 index 0000000000000000000000000000000000000000..0f16f95cea5518cfaeaab94dd7776c42baa37cfa --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/common/old_evo.txt @@ -0,0 +1,110 @@ +msa_stack.0.msa_row_attention_with_pair_bias.query_norm_gammas +msa_stack.0.msa_row_attention_with_pair_bias.query_norm_betas +msa_stack.0.msa_row_attention_with_pair_bias.feat_2d_norm_gammas +msa_stack.0.msa_row_attention_with_pair_bias.feat_2d_norm_betas +msa_stack.0.msa_row_attention_with_pair_bias.feat_2d_weights +msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_q_weights +msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_k_weights +msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_v_weights +msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_output_weights +msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.o_biases +msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_gating_weights +msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.gating_biases +msa_stack.0.msa_transition.input_layer_norm_gammas +msa_stack.0.msa_transition.input_layer_norm_betas +msa_stack.0.msa_transition.transition1_weights +msa_stack.0.msa_transition.transition1_biases +msa_stack.0.msa_transition.transition2_weights +msa_stack.0.msa_transition.transition2_biases +msa_stack.0.outer_product_mean.layer_norm_input_gammas +msa_stack.0.outer_product_mean.layer_norm_input_betas +msa_stack.0.outer_product_mean.left_projection_weights +msa_stack.0.outer_product_mean.left_projection_biases 
+msa_stack.0.outer_product_mean.right_projection_weights +msa_stack.0.outer_product_mean.right_projection_biases +msa_stack.0.outer_product_mean.linear_output_weights +msa_stack.0.outer_product_mean.o_biases +msa_stack.0.triangle_attention_starting_node.query_norm_gammas +msa_stack.0.triangle_attention_starting_node.query_norm_betas +msa_stack.0.triangle_attention_starting_node.feat_2d_weights +msa_stack.0.triangle_attention_starting_node.attn_mod.linear_q_weights +msa_stack.0.triangle_attention_starting_node.attn_mod.linear_k_weights +msa_stack.0.triangle_attention_starting_node.attn_mod.linear_v_weights +msa_stack.0.triangle_attention_starting_node.attn_mod.linear_output_weights +msa_stack.0.triangle_attention_starting_node.attn_mod.o_biases +msa_stack.0.triangle_attention_starting_node.attn_mod.linear_gating_weights +msa_stack.0.triangle_attention_starting_node.attn_mod.gating_biases +msa_stack.0.triangle_attention_ending_node.query_norm_gammas +msa_stack.0.triangle_attention_ending_node.query_norm_betas +msa_stack.0.triangle_attention_ending_node.feat_2d_weights +msa_stack.0.triangle_attention_ending_node.attn_mod.linear_q_weights +msa_stack.0.triangle_attention_ending_node.attn_mod.linear_k_weights +msa_stack.0.triangle_attention_ending_node.attn_mod.linear_v_weights +msa_stack.0.triangle_attention_ending_node.attn_mod.linear_output_weights +msa_stack.0.triangle_attention_ending_node.attn_mod.o_biases +msa_stack.0.triangle_attention_ending_node.attn_mod.linear_gating_weights +msa_stack.0.triangle_attention_ending_node.attn_mod.gating_biases +msa_stack.0.pair_transition.input_layer_norm_gammas +msa_stack.0.pair_transition.input_layer_norm_betas +msa_stack.0.pair_transition.transition1_weights +msa_stack.0.pair_transition.transition1_biases +msa_stack.0.pair_transition.transition2_weights +msa_stack.0.pair_transition.transition2_biases +msa_stack.0.triangle_multiplication_outgoing.layer_norm_input_gammas +msa_stack.0.triangle_multiplication_outgoing.layer_norm_input_betas +msa_stack.0.triangle_multiplication_outgoing.left_projection_weights +msa_stack.0.triangle_multiplication_outgoing.left_projection_biases +msa_stack.0.triangle_multiplication_outgoing.right_projection_weights +msa_stack.0.triangle_multiplication_outgoing.right_projection_biases +msa_stack.0.triangle_multiplication_outgoing.left_gate_weights +msa_stack.0.triangle_multiplication_outgoing.left_gate_biases +msa_stack.0.triangle_multiplication_outgoing.right_gate_weights +msa_stack.0.triangle_multiplication_outgoing.right_gate_biases +msa_stack.0.triangle_multiplication_outgoing.center_layer_norm_gammas +msa_stack.0.triangle_multiplication_outgoing.center_layer_norm_betas +msa_stack.0.triangle_multiplication_outgoing.output_projection_weights +msa_stack.0.triangle_multiplication_outgoing.output_projection_biases +msa_stack.0.triangle_multiplication_outgoing.gating_linear_weights +msa_stack.0.triangle_multiplication_outgoing.gating_linear_biases +msa_stack.0.triangle_multiplication_incoming.layer_norm_input_gammas +msa_stack.0.triangle_multiplication_incoming.layer_norm_input_betas +msa_stack.0.triangle_multiplication_incoming.left_projection_weights +msa_stack.0.triangle_multiplication_incoming.left_projection_biases +msa_stack.0.triangle_multiplication_incoming.right_projection_weights +msa_stack.0.triangle_multiplication_incoming.right_projection_biases +msa_stack.0.triangle_multiplication_incoming.left_gate_weights +msa_stack.0.triangle_multiplication_incoming.left_gate_biases 
+msa_stack.0.triangle_multiplication_incoming.right_gate_weights +msa_stack.0.triangle_multiplication_incoming.right_gate_biases +msa_stack.0.triangle_multiplication_incoming.center_layer_norm_gammas +msa_stack.0.triangle_multiplication_incoming.center_layer_norm_betas +msa_stack.0.triangle_multiplication_incoming.output_projection_weights +msa_stack.0.triangle_multiplication_incoming.output_projection_biases +msa_stack.0.triangle_multiplication_incoming.gating_linear_weights +msa_stack.0.triangle_multiplication_incoming.gating_linear_biases +msa_stack.0.attn_mod.query_norm_gammas +msa_stack.0.attn_mod.query_norm_betas +msa_stack.0.attn_mod.attn_mod.linear_q_weights +msa_stack.0.attn_mod.attn_mod.linear_k_weights +msa_stack.0.attn_mod.attn_mod.linear_v_weights +msa_stack.0.attn_mod.attn_mod.linear_output_weights +msa_stack.0.attn_mod.attn_mod.o_biases +msa_stack.0.attn_mod.attn_mod.linear_gating_weights +msa_stack.0.attn_mod.attn_mod.gating_biases +msa_stack.0.msa_row_attention_with_pair_bias.contact_norm_gammas +msa_stack.0.msa_row_attention_with_pair_bias.contact_norm_betas +msa_stack.0.msa_row_attention_with_pair_bias.contact_weights +msa_stack.0.msa_row_attention_with_pair_bias.sbr_norm_gammas +msa_stack.0.msa_row_attention_with_pair_bias.sbr_norm_betas +msa_stack.0.msa_row_attention_with_pair_bias.sbr_weights +msa_stack.0.msa_row_attention_with_pair_bias.add_interface.input_layer_norm_gammas +msa_stack.0.msa_row_attention_with_pair_bias.add_interface.input_layer_norm_betas +msa_stack.0.msa_row_attention_with_pair_bias.add_interface.linear_weights +msa_stack.0.msa_row_attention_with_pair_bias.add_interface.linear_biases +msa_stack.0.msa_row_attention_with_pair_bias.preprocess_sbr.input_layer_norm_gammas +msa_stack.0.msa_row_attention_with_pair_bias.preprocess_sbr.input_layer_norm_betas +msa_stack.0.msa_row_attention_with_pair_bias.preprocess_sbr.linear_weights +msa_stack.0.preprocess_sbr.input_layer_norm_gammas +msa_stack.0.preprocess_sbr.input_layer_norm_betas +msa_stack.0.preprocess_sbr.linear_weights +msa_stack.0.preprocess_sbr.linear_biases diff --git a/MindSPONGE/applications/research/Grasp/common/old_extra.txt b/MindSPONGE/applications/research/Grasp/common/old_extra.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a3cf374e7b5d53028836d3910ac5294e4f2c548 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/common/old_extra.txt @@ -0,0 +1,93 @@ +extra_msa_stack.0.msa_row_attention_with_pair_bias.query_norm_gammas +extra_msa_stack.0.msa_row_attention_with_pair_bias.query_norm_betas +extra_msa_stack.0.msa_row_attention_with_pair_bias.feat_2d_norm_gammas +extra_msa_stack.0.msa_row_attention_with_pair_bias.feat_2d_norm_betas +extra_msa_stack.0.msa_row_attention_with_pair_bias.feat_2d_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_q_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_k_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_v_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_output_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.o_biases +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.linear_gating_weights +extra_msa_stack.0.msa_row_attention_with_pair_bias.attn_mod.gating_biases +extra_msa_stack.0.msa_transition.input_layer_norm_gammas +extra_msa_stack.0.msa_transition.input_layer_norm_betas +extra_msa_stack.0.msa_transition.transition1_weights +extra_msa_stack.0.msa_transition.transition1_biases 
+extra_msa_stack.0.msa_transition.transition2_weights +extra_msa_stack.0.msa_transition.transition2_biases +extra_msa_stack.0.outer_product_mean.layer_norm_input_gammas +extra_msa_stack.0.outer_product_mean.layer_norm_input_betas +extra_msa_stack.0.outer_product_mean.left_projection_weights +extra_msa_stack.0.outer_product_mean.left_projection_biases +extra_msa_stack.0.outer_product_mean.right_projection_weights +extra_msa_stack.0.outer_product_mean.right_projection_biases +extra_msa_stack.0.outer_product_mean.linear_output_weights +extra_msa_stack.0.outer_product_mean.o_biases +extra_msa_stack.0.triangle_attention_starting_node.query_norm_gammas +extra_msa_stack.0.triangle_attention_starting_node.query_norm_betas +extra_msa_stack.0.triangle_attention_starting_node.feat_2d_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_q_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_k_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_v_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_output_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.o_biases +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.linear_gating_weights +extra_msa_stack.0.triangle_attention_starting_node.attn_mod.gating_biases +extra_msa_stack.0.triangle_attention_ending_node.query_norm_gammas +extra_msa_stack.0.triangle_attention_ending_node.query_norm_betas +extra_msa_stack.0.triangle_attention_ending_node.feat_2d_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_q_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_k_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_v_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_output_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.o_biases +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.linear_gating_weights +extra_msa_stack.0.triangle_attention_ending_node.attn_mod.gating_biases +extra_msa_stack.0.pair_transition.input_layer_norm_gammas +extra_msa_stack.0.pair_transition.input_layer_norm_betas +extra_msa_stack.0.pair_transition.transition1_weights +extra_msa_stack.0.pair_transition.transition1_biases +extra_msa_stack.0.pair_transition.transition2_weights +extra_msa_stack.0.pair_transition.transition2_biases +extra_msa_stack.0.triangle_multiplication_outgoing.layer_norm_input_gammas +extra_msa_stack.0.triangle_multiplication_outgoing.layer_norm_input_betas +extra_msa_stack.0.triangle_multiplication_outgoing.left_projection_weights +extra_msa_stack.0.triangle_multiplication_outgoing.left_projection_biases +extra_msa_stack.0.triangle_multiplication_outgoing.right_projection_weights +extra_msa_stack.0.triangle_multiplication_outgoing.right_projection_biases +extra_msa_stack.0.triangle_multiplication_outgoing.left_gate_weights +extra_msa_stack.0.triangle_multiplication_outgoing.left_gate_biases +extra_msa_stack.0.triangle_multiplication_outgoing.right_gate_weights +extra_msa_stack.0.triangle_multiplication_outgoing.right_gate_biases +extra_msa_stack.0.triangle_multiplication_outgoing.center_layer_norm_gammas +extra_msa_stack.0.triangle_multiplication_outgoing.center_layer_norm_betas +extra_msa_stack.0.triangle_multiplication_outgoing.output_projection_weights +extra_msa_stack.0.triangle_multiplication_outgoing.output_projection_biases +extra_msa_stack.0.triangle_multiplication_outgoing.gating_linear_weights 
+extra_msa_stack.0.triangle_multiplication_outgoing.gating_linear_biases
+extra_msa_stack.0.triangle_multiplication_incoming.layer_norm_input_gammas
+extra_msa_stack.0.triangle_multiplication_incoming.layer_norm_input_betas
+extra_msa_stack.0.triangle_multiplication_incoming.left_projection_weights
+extra_msa_stack.0.triangle_multiplication_incoming.left_projection_biases
+extra_msa_stack.0.triangle_multiplication_incoming.right_projection_weights
+extra_msa_stack.0.triangle_multiplication_incoming.right_projection_biases
+extra_msa_stack.0.triangle_multiplication_incoming.left_gate_weights
+extra_msa_stack.0.triangle_multiplication_incoming.left_gate_biases
+extra_msa_stack.0.triangle_multiplication_incoming.right_gate_weights
+extra_msa_stack.0.triangle_multiplication_incoming.right_gate_biases
+extra_msa_stack.0.triangle_multiplication_incoming.center_layer_norm_gammas
+extra_msa_stack.0.triangle_multiplication_incoming.center_layer_norm_betas
+extra_msa_stack.0.triangle_multiplication_incoming.output_projection_weights
+extra_msa_stack.0.triangle_multiplication_incoming.output_projection_biases
+extra_msa_stack.0.triangle_multiplication_incoming.gating_linear_weights
+extra_msa_stack.0.triangle_multiplication_incoming.gating_linear_biases
+extra_msa_stack.0.attn_mod.query_norm_gammas
+extra_msa_stack.0.attn_mod.query_norm_betas
+extra_msa_stack.0.attn_mod.attn_mod.linear_q_weights
+extra_msa_stack.0.attn_mod.attn_mod.linear_k_weights
+extra_msa_stack.0.attn_mod.attn_mod.linear_v_weights
+extra_msa_stack.0.attn_mod.attn_mod.linear_output_weights
+extra_msa_stack.0.attn_mod.attn_mod.o_biases
+extra_msa_stack.0.attn_mod.attn_mod.linear_gating_weights
+extra_msa_stack.0.attn_mod.attn_mod.gating_biases
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Grasp/common/protein.py b/MindSPONGE/applications/research/Grasp/common/protein.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe32958733f942ffb7c748a9adbfbf6c87a0f68b
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/common/protein.py
@@ -0,0 +1,190 @@
+# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Protein data type."""
+from typing import Any, Mapping
+import dataclasses
+
+import numpy as np
+
+from mindsponge1.common import residue_constants
+
+FeatureDict = Mapping[str, np.ndarray]
+ModelOutput = Mapping[str, Any]  # Is a nested dict.
+
+PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
+PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.
+
+
+@dataclasses.dataclass(frozen=True)
+class Protein:
+    """Protein structure representation."""
+
+    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
+    # residue_constants.atom_types, i.e. the first three are N, CA, CB.
+ atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # 0-indexed number corresponding to the chain in the protein that this residue + # belongs to. + chain_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ['X'] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + chain_index = prot.chain_index.astype(np.int32) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. + if i >= PDB_MAX_CHAINS: + raise ValueError( + f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') + chain_ids[i] = PDB_CHAIN_IDS[i] + + pdb_lines.append('MODEL 1') + atom_index = 1 + last_chain_index = chain_index[0] + # Add all atom sites. + for i in range(aatype.shape[0]): + if last_chain_index != chain_index[i]: + chain_end = 'TER' + chain_termination_line = ( + f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[i - 1]):>3} ' + f'{chain_ids[chain_index[i - 1]]:>1}{residue_index[i - 1]:>4}') + pdb_lines.append(chain_termination_line) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[0] # Protein supports only C, N, O, S, this works. + charge = '' + # PDB is a columnar format, every space matters here! + atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the chain. + chain_end = 'TER' + chain_termination_line = ( + f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} ' + f'{chain_ids[chain_index[-1]]:>1}{residue_index[-1]:>4}') + pdb_lines.append(chain_termination_line) + pdb_lines.append('ENDMDL') + + pdb_lines.append('END') + pdb_lines.append('') + return '\n'.join(pdb_lines) + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. 
+ + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction(final_atom_positions, + final_atom_mask, + aatype, + residue_index, + b_factors=None, + asym_id=None, + remove_leading_feature_dimension=True) -> Protein: + """Assembles a protein from a prediction. + + Args: + final_atom_positions: atom positions + final_atom_mask: atom mask + aatype: amino acid type + residue_index: idx of the residue + Returns: + A protein instance. + """ + def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: + return arr[0] if remove_leading_feature_dimension else arr + + if asym_id is not None: + chain_index = _maybe_remove_leading_dim(asym_id) + else: + chain_index = np.zeros_like(aatype) + if b_factors is None: + b_factors = np.zeros_like(final_atom_mask) + + return Protein( + aatype=aatype, + atom_positions=final_atom_positions, + atom_mask=final_atom_mask, + residue_index=residue_index + 1, + chain_index=chain_index, + b_factors=b_factors) diff --git a/MindSPONGE/applications/research/Grasp/common/utils.py b/MindSPONGE/applications/research/Grasp/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbd64f6fe9735414a61c06e603fc63ec9ad454f --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/common/utils.py @@ -0,0 +1,309 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
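A minimal usage sketch for the `common/protein.py` module added above. The shapes and the all-alanine placeholder are illustrative only, and it assumes the module imports as `common.protein` per the repository layout:

```python
import numpy as np
from common import protein

num_res = 10
# Unbatched placeholder inputs: all-alanine, every atom37 slot marked present,
# all atoms at the origin.
prot = protein.from_prediction(
    final_atom_positions=np.zeros((num_res, 37, 3)),
    final_atom_mask=np.ones((num_res, 37)),
    aatype=np.zeros(num_res, dtype=np.int32),
    residue_index=np.arange(num_res),   # from_prediction stores this 1-based
)
pdb_str = protein.to_pdb(prot)          # columnar PDB text: MODEL/ATOM/TER/ENDMDL/END
print(pdb_str.splitlines()[0])          # first record of the model
```

Because `asym_id` is omitted, every residue lands on chain index 0 (chain `A` in the output); `b_factors` defaults to zeros of the same shape as the atom mask.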
+# ============================================================================
+"""utils module"""
+
+import re
+
+import numpy as np
+from scipy.special import softmax
+
+import mindspore.numpy as mnp
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore import Parameter
+from mindsponge1.common.geometry import vecs_from_tensor
+from common.geometry import multimer_rigids_compute_dihedral_angle
+
+
+def trans_ckpt(param_dict):
+    """Adapt a per-layer checkpoint to the stacked-parameter layout.
+
+    Drops optimizer state, folds every `msa_stack.<i>.*` parameter into a
+    single `msa_stack.*` parameter stacked along a new leading axis, and
+    trims the 22-channel single-residue embedding weights to 21 channels.
+    """
+    new_param_dict = {}
+    for k, v in param_dict.items():
+        # Skip optimizer state and internal template-slicing buffers.
+        if re.search(r'learning_rate|global_step|moment[12]|beta[12]_power|vhat|template_embedding\._flat_.*_slice', k):
+            continue
+        if re.search(r'^msa_stack', k):
+            # Collect the per-layer tensors under the layer-free key.
+            new_k = re.sub(r'^(msa_stack\.)\d+\.', r'\1', k)
+            if new_k in new_param_dict:
+                new_param_dict[new_k].append(v.asnumpy()[None])
+            else:
+                new_param_dict[new_k] = [v.asnumpy()[None]]
+        else:
+            new_param_dict[k] = v
+    for k, v in new_param_dict.items():
+        if re.search(r'^msa_stack', k):
+            new_param_dict[k] = Parameter(np.concatenate(new_param_dict[k], axis=0))
+    for key, value in new_param_dict.items():
+        if (('preprocess_1d.weight' in key) or ('left_single.weight' in key)
+                or ('right_single.weight' in key)) and (new_param_dict[key].shape[-1] == 22):
+            new_param_dict[key] = Parameter(new_param_dict[key][..., 1:])
+    return new_param_dict
+
+
+class CompuyeChiAngles(nn.Cell):
+    """Chi-angle computation sharded over two devices (see ComputeChiAngles
+    below for the variant parameterized by device_num)."""
+
+    def __init__(self):
+        super(CompuyeChiAngles, self).__init__()
+        self.equal = P.Equal()
+        self.minimum = P.Minimum().shard(((1, 2), ()))
+        self.reshape = P.Reshape().shard(((1, 2, 1, 1), ()))
+        self.concat = P.Concat(4).shard(((1, 2, 1, 1, 1), (1, 2, 1, 1, 1), (1, 2, 1, 1, 1)))
+        self.gathernd1 = P.GatherNd()  # .shard(((1,8,1,1),(1,8,1,1,1)))
+        self.gathernd2 = P.GatherNd()  # .shard(((1,8,1), (1,8,1,1,1)))
+        self.reduceprod = P.ReduceProd().shard(((1, 2, 1, 1),))
+        self.mul = P.Mul().shard(((1, 2, 1), (1, 2, 1)))
+        self.stack = P.Stack().shard(((2, 1), (2, 1), (2, 1), (2, 1)))
+
+    def construct(self, aatype,          # (B, N)
+                  all_atom_pos,          # (B, N, 37, 3)
+                  all_atom_mask,         # (B, N, 37)
+                  chi_atom_indices,
+                  chi_angles_mask,
+                  indices0,
+                  indices1):
+        aatype = self.minimum(aatype, 20)
+        # Collect the atoms for the chi-angles.
+        # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
+        # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
+        atom_indices = mnp.take(chi_atom_indices, aatype, axis=0)
+        # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
+        seq_length = all_atom_pos.shape[1]
+        atom_indices = self.reshape(atom_indices, tuple((4, seq_length, 4, 4, 1))).astype("int32")
+        new_indices = self.concat((indices0, indices1, atom_indices))
+        chis_atom_pos = self.gathernd1(all_atom_pos, new_indices)
+        chis_mask = mnp.take(chi_angles_mask, aatype, axis=0)
+        chi_angle_atoms_mask = self.gathernd2(all_atom_mask, new_indices)
+        # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
+        chi_angle_atoms_mask = self.reduceprod(chi_angle_atoms_mask, -1)
+        chis_mask = self.mul(chis_mask, chi_angle_atoms_mask.astype(mnp.float32))
+        all_chi_angles = []
+        for i in range(aatype.shape[0]):
+            template_chi_angles = multimer_rigids_compute_dihedral_angle(
+                vecs_from_tensor(chis_atom_pos[i, :, :, 0, :]),
+                vecs_from_tensor(chis_atom_pos[i, :, :, 1, :]),
+                vecs_from_tensor(chis_atom_pos[i, :, :, 2, :]),
+                vecs_from_tensor(chis_atom_pos[i, :, :, 3, :]))
+            all_chi_angles.append(template_chi_angles)
+        chi_angles = self.stack(all_chi_angles)
+        return chi_angles, chis_mask
+
+
+def compute_chi_angles(aatype,          # (B, N)
+                       all_atom_pos,    # (B, N, 37, 3)
+                       all_atom_mask,   # (B, N, 37)
+                       chi_atom_indices,
+                       chi_angles_mask,
+                       indices0,
+                       indices1):
+    """Compute chi dihedral angles and their validity mask."""
+    aatype = mnp.minimum(aatype, 20)
+    # Collect the atoms for the chi-angles.
+    # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
+    # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
+    atom_indices = chi_atom_indices[aatype, ...]
+    # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
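+    # Note: indices0 and indices1 are index grids supplied by the caller over
+    # the batch and residue axes; concatenating them with atom_indices below
+    # yields full [batch, res, chi, atom] coordinates, so a single GatherNd
+    # fetches all four atom positions of every chi group at once.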
+ + # 4 seq_length 4 4 batch, sequence length, chis, atoms + seq_length = all_atom_pos.shape[1] + atom_indices = atom_indices.reshape((4, seq_length, 4, 4, 1)).astype("int32") + new_indices = P.Concat(4)((indices0, indices1, atom_indices)) + chis_atom_pos = P.GatherNd()(all_atom_pos, new_indices) + # chis_mask = mnp.take(chi_angles_mask, aatype, axis=0) + chis_mask = chi_angles_mask[aatype,:] + chi_angle_atoms_mask = P.GatherNd()(all_atom_mask, new_indices) + + # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4]. + chi_angle_atoms_mask = P.ReduceProd()(chi_angle_atoms_mask, -1) + chis_mask = chis_mask * (chi_angle_atoms_mask).astype(mnp.float32) + all_chi_angles = [] + for i in range(aatype.shape[0]): + template_chi_angles = multimer_rigids_compute_dihedral_angle(vecs_from_tensor(chis_atom_pos[i, :, :, 0, :]), + vecs_from_tensor(chis_atom_pos[i, :, :, 1, :]), + vecs_from_tensor(chis_atom_pos[i, :, :, 2, :]), + vecs_from_tensor(chis_atom_pos[i, :, :, 3, :])) + all_chi_angles.append(template_chi_angles) + chi_angles = mnp.stack(all_chi_angles, axis=0) + return chi_angles, chis_mask + + +class ComputeChiAngles(nn.Cell): + def __init__(self, device_num): + super(ComputeChiAngles, self).__init__() + self.equal = P.Equal() + self.minimum = P.Minimum().shard(((1,device_num),())) + self.reshape = P.Reshape().shard(((1,device_num,1,1),())) + self.concat = P.Concat(4).shard(((1, device_num, 1, 1, 1), (1, device_num, 1, 1, 1),(1, device_num, 1, 1, 1))) + self.gathernd1 = P.GatherNd()#.shard(((1,8,1,1),(1,8,1,1,1))) + self.gathernd2 = P.GatherNd()#.shard(((1,8,1), (1,8,1,1,1))) + self.reduceprod = P.ReduceProd().shard(((1,device_num,1,1),)) + self.mul = P.Mul().shard(((1,device_num,1),(1,device_num,1))) + self.stack = P.Stack().shard(((device_num,1),(device_num,1),(device_num,1),(device_num,1))) + def construct(self, aatype, # (B, N) + all_atom_pos, # (B, N, 37, 3) + all_atom_mask, # (B, N, 37) + chi_atom_indices, + chi_angles_mask, + indices0, + indices1): + aatype = self.minimum(aatype, 20) + # Collect the atoms for the chi-angles. + # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. + # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4]. + atom_indices = mnp.take(chi_atom_indices, aatype, axis=0) + + # # Gather atom positions Batch Gather. Shape: [batch, num_res, chis=4, atoms=4, xyz=3]. + + # 4 seq_length 4 4 batch, sequence length, chis, atoms + seq_length = all_atom_pos.shape[1] + atom_indices = self.reshape(atom_indices, tuple((4, seq_length, 4, 4, 1))).astype("int32") + new_indices = self.concat((indices0, indices1, atom_indices)) + chis_atom_pos = self.gathernd1(all_atom_pos, new_indices) + chis_mask = mnp.take(chi_angles_mask, aatype, axis=0) + chi_angle_atoms_mask = self.gathernd2(all_atom_mask, new_indices) + # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4]. 
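+        # (The product over the atom axis is 1.0 only when all four atoms of a
+        # chi group are resolved, so chis with any missing atom are masked out.)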
+        chi_angle_atoms_mask = self.reduceprod(chi_angle_atoms_mask, -1)
+        chis_mask = self.mul(chis_mask, chi_angle_atoms_mask.astype(mnp.float32))
+        all_chi_angles = []
+        for i in range(aatype.shape[0]):
+            template_chi_angles = multimer_rigids_compute_dihedral_angle(
+                vecs_from_tensor(chis_atom_pos[i, :, :, 0, :]),
+                vecs_from_tensor(chis_atom_pos[i, :, :, 1, :]),
+                vecs_from_tensor(chis_atom_pos[i, :, :, 2, :]),
+                vecs_from_tensor(chis_atom_pos[i, :, :, 3, :]))
+            all_chi_angles.append(template_chi_angles)
+        chi_angles = self.stack(all_chi_angles)
+        return chi_angles, chis_mask
+
+
+def compute_confidence(predicted_lddt_logits, return_lddt=False):
+    """Compute the mean pLDDT confidence from predicted-LDDT logits."""
+    num_bins = predicted_lddt_logits.shape[-1]
+    bin_width = 1 / num_bins
+    start_n = bin_width / 2
+    # e.g. with num_bins=50: bin_width=0.02 and bin centers 0.01, 0.03, ..., 0.99.
+    plddt = compute_plddt(predicted_lddt_logits, start_n, bin_width)
+    confidence = np.mean(plddt)
+    if return_lddt:
+        return confidence, plddt
+    return confidence
+
+
+def compute_plddt(logits, start_n, bin_width):
+    """Computes per-residue pLDDT from logits.
+
+    Args:
+        logits: [num_res, num_bins] output from the PredictedLDDTHead.
+        start_n: center of the first bin.
+        bin_width: width of each bin.
+
+    Returns:
+        plddt: [num_res] per-residue pLDDT in [0, 100].
+    """
+    bin_centers = np.arange(start=start_n, stop=1.0, step=bin_width)
+    probs = softmax(logits, axis=-1)
+    predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
+    return predicted_lddt_ca * 100
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Grasp/config/data-infer.yaml b/MindSPONGE/applications/research/Grasp/config/data-infer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fcf35ab1649b163fce77bd9fc9e215dbfb3d54d6
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/config/data-infer.yaml
@@ -0,0 +1,88 @@
+block_deletion:
+  msa_fraction_per_block: 0.3
+  num_blocks: 5
+  randomize_num_blocks: True
+common:
+  random_recycle: True
+  distillation: False
+  replace_proportion: 0.0
+  masked_msa:
+    use_masked_msa: True
+    profile_prob: 0.1
+    same_prob: 0.1
+    uniform_prob: 0.1
+  max_extra_msa: 2048
+  msa_cluster_features: True
+  num_recycle: 4
+  reduce_msa_clusters_by_max_templates: True
+  resample_msa_in_recycling: True
+  use_templates: True
+  template_features:
+    - template_all_atom_positions
+    - template_sum_probs
+    - template_aatype
+    - template_all_atom_masks
+    - template_domain_names
+  unsupervised_features:
+    - aatype
+    - residue_index
+    - sequence
+    - msa
+    - domain_name
+    - num_alignments
+    - seq_length
+    - between_segment_residues
+    - deletion_matrix
+    - template_all_atom_positions
+    - template_sum_probs
+    - template_aatype
+    - template_all_atom_masks
+    - template_domain_names
+  supervised_features:
+    - all_atom_positions
+    - all_atom_mask
+    - atom14_atom_exists
+    - atom14_gt_exists
+    - atom14_gt_positions
+    - residx_atom14_to_atom37
+    - residx_atom37_to_atom14
+    - atom37_atom_exists
+    - atom14_alt_gt_positions
+    - atom14_alt_gt_exists
+    - atom14_atom_is_ambiguous
+    - rigidgroups_gt_frames
+    - rigidgroups_gt_exists
+    - rigidgroups_group_exists
+    - rigidgroups_group_is_ambiguous
+    - rigidgroups_alt_gt_frames
+    - backbone_affine_tensor
+    - torsion_angles_sin_cos
+    - alt_torsion_angles_sin_cos
+    - torsion_angles_mask
+    - pseudo_beta
+    - pseudo_beta_mask
+    - chi_mask
+    - backbone_affine_mask
+
+
+eval:
+  crop_size: 256
+  fixed_size: True
+  masked_msa_replace_fraction: 0.15
+  max_msa_clusters: 512
+  max_templates: 4
+  num_ensemble: 1
+  subsample_templates: True
+  keep_extra: True
+
+database_search:
+  hhsearch_binary_path:
None + kalign_binary_path: None + pdb70_database_path: None + mmcif_dir: None + obsolete_pdbs_path: None + max_template_date: "2100-01-01" + mmseqs_binary: None + uniref30_path: None + database_envdb_dir: None + a3m_result_path: "./a3m_result/" diff --git a/MindSPONGE/applications/research/Grasp/config/model-infer.yaml b/MindSPONGE/applications/research/Grasp/config/model-infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f04b688cefc0bc338ef770a1664181968b12dbad --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/config/model-infer.yaml @@ -0,0 +1,845 @@ +is_training: False +msa_channel: 256 +pair_channel: 128 +extra_msa_channel: 64 +max_relative_feature: 32 +recycle_features: True +recycle_pos: True +seq_channel: 384 +prev_pos: + min_bin: 3.25 + max_bin: 20.75 + num_bins: 15 +common: + target_feat_dim: 21 + msa_feat_dim: 49 + dgram_dim: 15 + pair_in_dim: 65 + msa_first_row_dim: 256 + prev_pair_dim: 128 + extra_msa_dim: 25 + template_feat_dim: 57 +template: + enabled: True + embed_torsion_angles: True + use_template_unit_vector: True + attention: + gating: False + key_dim: 64 + num_head: 4 + value_dim: 64 + dgram_features: + min_bin: 3.25 + max_bin: 50.75 + num_bins: 39 + template_pair_stack: + num_block: 2 + triangle_attention_starting_node: + dropout_rate: 0.25 + gating: True + key_dim: 64 + num_head: 4 + orientation: 'per_row' + shared_dropout: True + value_dim: 64 + triangle_attention_ending_node: + dropout_rate: 0.25 + gating: True + key_dim: 64 + num_head: 4 + orientation: 'per_column' + shared_dropout: True + value_dim: 64 + triangle_multiplication_outgoing: + dropout_rate: 0.25 + equation: 'ikc,jkc->ijc' + num_intermediate_channel: 64 + orientation: 'per_row' + shared_dropout: True + triangle_multiplication_incoming: + dropout_rate: 0.25 + equation: 'kjc,kic->ijc' + num_intermediate_channel: 64 + orientation: 'per_row' + shared_dropout: True + pair_transition: + dropout_rate: 0.0 + num_intermediate_factor: 2 + orientation: 'per_row' + shared_dropout: True +evoformer: + msa_stack_num: 48 + extra_msa_stack_num: 4 + msa_row_attention_with_pair_bias: + dropout_rate: 0.15 # 0.15 + gating: True + num_head: 8 + orientation: 'per_row' + shared_dropout: True + msa_column_attention: + dropout_rate: 0.0 + gating: True + num_head: 8 + orientation: 'per_column' + shared_dropout: True + msa_transition: + dropout_rate: 0.0 + num_intermediate_factor: 4 + orientation: 'per_row' + shared_dropout: True + outer_product_mean: + chunk_size: 128 + dropout_rate: 0.0 + num_outer_channel: 32 + orientation: 'per_row' + shared_dropout: True + triangle_attention_starting_node: + dropout_rate: 0.25 # 0.25 + gating: True + num_head: 4 + orientation: 'per_row' + shared_dropout: True + triangle_attention_ending_node: + dropout_rate: 0.25 # 0.25 + gating: True + num_head: 4 + orientation: 'per_column' + shared_dropout: True + triangle_multiplication_outgoing: + dropout_rate: 0.25 # 0.25 + equation: 'ikc,jkc->ijc' + num_intermediate_channel: 128 + orientation: 'per_row' + shared_dropout: True + triangle_multiplication_incoming: + dropout_rate: 0.25 # 0.25 + equation: 'kjc,kic->ijc' + num_intermediate_channel: 128 + orientation: 'per_row' + shared_dropout: True + pair_transition: + dropout_rate: 0.0 + num_intermediate_factor: 4 + orientation: 'per_row' + shared_dropout: True +structure_module: + num_layer: 8 + fape: + clamp_distance: 10.0 + clamp_type: 'relu' + loss_unit_distance: 10.0 + angle_norm_weight: 0.01 + chi_weight: 0.5 + clash_overlap_tolerance: 1.5 + 
compute_in_graph_metrics: True + dropout: 0.1 + num_channel: 384 + num_head: 12 + num_layer_in_transition: 3 + num_point_qk: 4 + num_point_v: 8 + num_scalar_qk: 16 + num_scalar_v: 16 + position_scale: 20.0 + sidechain: + atom_clamp_distance: 10.0 + num_channel: 128 + num_residual_block: 2 + weight_frac: 0.5 + length_scale: 10. + structural_violation_loss_weight: 1.0 + violation_tolerance_factor: 12.0 + weight: 1.0 +slice: + seq_248: + template_embedding: 4 + template_pair_stack: + triangle_attention_starting_node: 4 + triangle_attention_ending_node: 4 + pair_transition: 4 + extra_msa_stack: + msa_transition: 4 + msa_row_attention_with_pair_bias: 4 + msa_column_global_attention: 4 + outer_product_mean: 4 + triangle_attention_starting_node: 4 + triangle_attention_ending_node: 4 + pair_transition: 4 + msa_stack: + msa_transition: 4 + msa_row_attention_with_pair_bias: 4 + msa_column_attention: 4 + outer_product_mean: 4 + triangle_attention_starting_node: 4 + triangle_attention_ending_node: 4 + pair_transition: 4 + seq_256: + template_embedding: 2 + template_pair_stack: + triangle_attention_starting_node: 2 + triangle_attention_ending_node: 2 + pair_transition: 2 + extra_msa_stack: + msa_transition: 2 + msa_row_attention_with_pair_bias: 4 + msa_column_global_attention: 2 + outer_product_mean: 2 + triangle_attention_starting_node: 2 + triangle_attention_ending_node: 2 + pair_transition: 2 + msa_stack: + msa_transition: 2 + msa_row_attention_with_pair_bias: 2 + msa_column_attention: 2 + outer_product_mean: 2 + triangle_attention_starting_node: 2 + triangle_attention_ending_node: 2 + pair_transition: 2 + seq_512: + template_embedding: 8 + template_pair_stack: + triangle_attention_starting_node: 8 + triangle_attention_ending_node: 8 + pair_transition: 8 + extra_msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 64 + msa_column_global_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 0 + msa_column_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + seq_672: + template_embedding: 8 + template_pair_stack: + triangle_attention_starting_node: 8 + triangle_attention_ending_node: 8 + pair_transition: 8 + extra_msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 128 + msa_column_global_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 0 + msa_column_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + seq_768: + template_embedding: 8 + template_pair_stack: + triangle_attention_starting_node: 8 + triangle_attention_ending_node: 8 + pair_transition: 8 + extra_msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 128 + msa_column_global_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 0 + msa_column_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + seq_1024: + template_embedding: 16 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 16 # seq len + 
triangle_attention_ending_node: 16 # seq len + pair_transition: 16 # seq len + extra_msa_stack: + msa_transition: 1 # 5120 + msa_row_attention_with_pair_bias: 128 # 5120 + msa_column_global_attention: 16 # seq len + outer_product_mean: 1 # seq len + triangle_attention_starting_node: 16 # seq len + triangle_attention_ending_node: 16 # seq len + pair_transition: 1 # seq len + msa_stack: + msa_transition: 1 + msa_row_attention_with_pair_bias: 16 + msa_column_attention: 16 + outer_product_mean: 1 + triangle_attention_starting_node: 16 + triangle_attention_ending_node: 16 + pair_transition: 1 + seq_1280: + template_embedding: 16 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 16 # seq len + triangle_attention_ending_node: 16 # seq len + pair_transition: 16 # seq len + extra_msa_stack: + msa_transition: 1 # 5120 + msa_row_attention_with_pair_bias: 128 # 5120 + msa_column_global_attention: 16 # seq len + outer_product_mean: 1 # seq len + triangle_attention_starting_node: 16 # seq len + triangle_attention_ending_node: 16 # seq len + pair_transition: 1 # seq len + msa_stack: + msa_transition: 1 + msa_row_attention_with_pair_bias: 16 + msa_column_attention: 16 + outer_product_mean: 1 + triangle_attention_starting_node: 16 + triangle_attention_ending_node: 16 + pair_transition: 1 + # template_embedding: 8 # seq len * seq len + # template_pair_stack: + # triangle_attention_starting_node: 32 # seq len + # triangle_attention_ending_node: 32 # seq len + # pair_transition: 8 # seq len + # extra_msa_stack: + # msa_transition: 0 # 5120 + # msa_row_attention_with_pair_bias: 128 # 5120 + # msa_column_global_attention: 8 # seq len + # outer_product_mean: 0 # seq len + # triangle_attention_starting_node: 8 # seq len + # triangle_attention_ending_node: 8 # seq len + # pair_transition: 0 # seq len + # msa_stack: + # msa_transition: 0 + # msa_row_attention_with_pair_bias: 8 + # msa_column_attention: 8 + # outer_product_mean: 0 + # triangle_attention_starting_node: 8 + # triangle_attention_ending_node: 8 + # pair_transition: 0 + seq_1408: + template_embedding: 16 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 16 # seq len + triangle_attention_ending_node: 16 # seq len + pair_transition: 16 # seq len + extra_msa_stack: + msa_transition: 1 # 5120 + msa_row_attention_with_pair_bias: 128 # 5120 + msa_column_global_attention: 16 # seq len + outer_product_mean: 1 # seq len + triangle_attention_starting_node: 16 # seq len + triangle_attention_ending_node: 16 # seq len + pair_transition: 1 # seq len + msa_stack: + msa_transition: 1 + msa_row_attention_with_pair_bias: 16 + msa_column_attention: 16 + outer_product_mean: 1 + triangle_attention_starting_node: 16 + triangle_attention_ending_node: 16 + pair_transition: 1 + seq_1664: + template_embedding: 16 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 16 # seq len + triangle_attention_ending_node: 16 # seq len + pair_transition: 16 # seq len + extra_msa_stack: + msa_transition: 2 # 5120 + msa_row_attention_with_pair_bias: 256 # 5120 + msa_column_global_attention: 32 # seq len + outer_product_mean: 2 # seq len + triangle_attention_starting_node: 32 # seq len + triangle_attention_ending_node: 32 # seq len + pair_transition: 2 # seq len + msa_stack: + msa_transition: 2 + msa_row_attention_with_pair_bias: 32 + msa_column_attention: 32 + outer_product_mean: 2 + triangle_attention_starting_node: 32 + triangle_attention_ending_node: 32 + pair_transition: 2 + seq_1536: + template_embedding: 
16 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 32 # seq len + triangle_attention_ending_node: 32 # seq len + pair_transition: 8 # seq len + extra_msa_stack: + msa_transition: 8 # 5120 + msa_row_attention_with_pair_bias: 256 # 5120 + msa_column_global_attention: 32 # seq len + outer_product_mean: 8 # seq len + triangle_attention_starting_node: 32 # seq len + triangle_attention_ending_node: 32 # seq len + pair_transition: 8 # seq len + msa_stack: + msa_transition: 8 + msa_row_attention_with_pair_bias: 32 + msa_column_attention: 32 + outer_product_mean: 8 + triangle_attention_starting_node: 32 + triangle_attention_ending_node: 32 + pair_transition: 8 + seq_1792: + template_embedding: 64 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 64 # seq len + triangle_attention_ending_node: 64 # seq len + pair_transition: 8 # seq len + extra_msa_stack: + msa_transition: 8 # 5120 + msa_row_attention_with_pair_bias: 512 # 5120 + msa_column_global_attention: 64 # seq len + outer_product_mean: 8 # seq len + triangle_attention_starting_node: 64 # seq len + triangle_attention_ending_node: 64 # seq len + pair_transition: 8 # seq len + msa_stack: + msa_transition: 8 + msa_row_attention_with_pair_bias: 64 + msa_column_attention: 64 + outer_product_mean: 8 + triangle_attention_starting_node: 64 + triangle_attention_ending_node: 64 + pair_transition: 8 + seq_2048: + template_embedding: 128 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 128 # seq len + extra_msa_stack: + msa_transition: 128 # 5120 + msa_row_attention_with_pair_bias: 512 # 5120 + msa_column_global_attention: 128 # seq len + outer_product_mean: 128 # seq len + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 128 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 128 + msa_column_attention: 128 + outer_product_mean: 128 + triangle_attention_starting_node: 128 + triangle_attention_ending_node: 128 + pair_transition: 128 + seq_2304: + template_embedding: 128 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 128 # seq len + extra_msa_stack: + msa_transition: 128 # 5120 + msa_row_attention_with_pair_bias: 512 # 5120 + msa_column_global_attention: 256 # seq len + outer_product_mean: 128 # seq len + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 128 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 256 + msa_column_attention: 256 + outer_product_mean: 256 + triangle_attention_starting_node: 256 + triangle_attention_ending_node: 256 + pair_transition: 128 + seq_3072: + template_embedding: 256 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 512 # seq len + triangle_attention_ending_node: 512 # seq len + pair_transition: 256 # seq len + extra_msa_stack: + msa_transition: 256 # 5120 + msa_row_attention_with_pair_bias: 512 # 5120 + msa_column_global_attention: 512 # seq len + outer_product_mean: 256 # seq len + triangle_attention_starting_node: 512 # seq len + triangle_attention_ending_node: 512 # seq len + pair_transition: 256 # seq len + msa_stack: + msa_transition: 256 + msa_row_attention_with_pair_bias: 512 + msa_column_attention: 512 + 
outer_product_mean: 512 + triangle_attention_starting_node: 512 + triangle_attention_ending_node: 512 + pair_transition: 256 + seq_4096: + template_embedding: 128 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 128 # seq len + extra_msa_stack: + msa_transition: 128 # 5120 + msa_row_attention_with_pair_bias: 512 # 5120 + msa_column_global_attention: 128 # seq len + outer_product_mean: 128 # seq len + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 128 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 128 + msa_column_attention: 128 + outer_product_mean: 128 + triangle_attention_starting_node: 128 + triangle_attention_ending_node: 128 + pair_transition: 128 + seq_6144: + template_embedding: 32 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 16 # seq len + extra_msa_stack: + msa_transition: 16 # 5120 + msa_row_attention_with_pair_bias: 64 # 5120 + msa_column_global_attention: 64 # seq len + outer_product_mean: 32 # seq len + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 16 # seq len + msa_stack: + msa_transition: 64 + msa_row_attention_with_pair_bias: 64 + msa_column_attention: 64 + outer_product_mean: 64 + triangle_attention_starting_node: 128 + triangle_attention_ending_node: 128 + pair_transition: 64 + # seq_7168: + # template_embedding: 128 # seq len * seq len + # template_pair_stack: + # triangle_attention_starting_node: 128 # seq len + # triangle_attention_ending_node: 128 # seq len + # pair_transition: 128 # seq len + # extra_msa_stack: + # msa_transition: 128 # 5120 + # msa_row_attention_with_pair_bias: 512 # 5120 + # msa_column_global_attention: 128 # seq len + # outer_product_mean: 128 # seq len + # triangle_attention_starting_node: 128 # seq len + # triangle_attention_ending_node: 128 # seq len + # pair_transition: 128 # seq len + # msa_stack: + # msa_transition: 128 + # msa_row_attention_with_pair_bias: 128 + # msa_column_attention: 128 + # outer_product_mean: 128 + # triangle_attention_starting_node: 128 + # triangle_attention_ending_node: 128 + # pair_transition: 128 + seq_6912: + template_embedding: 32 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 16 # seq len + extra_msa_stack: + msa_transition: 16 # 5120 + msa_row_attention_with_pair_bias: 64 # 5120 + msa_column_global_attention: 64 # seq len + outer_product_mean: 32 # seq len + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 16 # seq len + msa_stack: + msa_transition: 64 + msa_row_attention_with_pair_bias: 64 + msa_column_attention: 64 + outer_product_mean: 64 + triangle_attention_starting_node: 128 + triangle_attention_ending_node: 128 + pair_transition: 64 + seq_7680: + template_embedding: 64 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 32 # seq len + extra_msa_stack: + msa_transition: 32 # 5120 + msa_row_attention_with_pair_bias: 128 # 5120 + msa_column_global_attention: 128 # seq len + outer_product_mean: 64 # seq len + 
triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 32 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 128 + msa_column_attention: 128 + outer_product_mean: 128 + triangle_attention_starting_node: 256 + triangle_attention_ending_node: 256 + pair_transition: 128 + seq_7552: + template_embedding: 32 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 16 # seq len + extra_msa_stack: + msa_transition: 16 # 5120 + msa_row_attention_with_pair_bias: 64 # 5120 + msa_column_global_attention: 64 # seq len + outer_product_mean: 32 # seq len + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 16 # seq len + msa_stack: + msa_transition: 64 + msa_row_attention_with_pair_bias: 64 + msa_column_attention: 64 + outer_product_mean: 64 + triangle_attention_starting_node: 128 + triangle_attention_ending_node: 128 + pair_transition: 64 + seq_8064: + template_embedding: 32 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 16 # seq len + extra_msa_stack: + msa_transition: 16 # 5120 + msa_row_attention_with_pair_bias: 64 # 5120 + msa_column_global_attention: 64 # seq len + outer_product_mean: 32 # seq len + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 16 # seq len + msa_stack: + msa_transition: 64 + msa_row_attention_with_pair_bias: 64 + msa_column_attention: 64 + outer_product_mean: 64 + triangle_attention_starting_node: 128 + triangle_attention_ending_node: 128 + pair_transition: 64 + seq_7936: + template_embedding: 64 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 32 # seq len + extra_msa_stack: + msa_transition: 32 # 5120 + msa_row_attention_with_pair_bias: 128 # 5120 + msa_column_global_attention: 128 # seq len + outer_product_mean: 64 # seq len + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 32 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 128 + msa_column_attention: 128 + outer_product_mean: 128 + triangle_attention_starting_node: 256 + triangle_attention_ending_node: 256 + pair_transition: 128 + seq_8192: + template_embedding: 64 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 32 # seq len + extra_msa_stack: + msa_transition: 32 # 5120 + msa_row_attention_with_pair_bias: 128 # 5120 + msa_column_global_attention: 128 # seq len + outer_product_mean: 64 # seq len + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 32 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 128 + msa_column_attention: 128 + outer_product_mean: 128 + triangle_attention_starting_node: 256 + triangle_attention_ending_node: 256 + pair_transition: 128 + # seq_8000: + # template_embedding: 80 # seq len * seq len + # template_pair_stack: + # triangle_attention_starting_node: 250 # seq len + # triangle_attention_ending_node: 250 # seq len + # pair_transition: 40 # seq len + # 
extra_msa_stack: + # msa_transition: 40 # 5120 + # msa_row_attention_with_pair_bias: 125 # 5120 + # msa_column_global_attention: 125 # seq len + # outer_product_mean: 80 # seq len + # triangle_attention_starting_node: 250 # seq len + # triangle_attention_ending_node: 250 # seq len + # pair_transition: 40 # seq len + # msa_stack: + # msa_transition: 125 + # msa_row_attention_with_pair_bias: 125 + # msa_column_attention: 125 + # outer_product_mean: 125 + # triangle_attention_starting_node: 250 + # triangle_attention_ending_node: 250 + # pair_transition: 125 + # seq_8192: + # template_embedding: 128 # + # template_pair_stack: + # triangle_attention_starting_node: 128 # seq len + # triangle_attention_ending_node: 128 # seq len + # pair_transition: 128 # seq len + # extra_msa_stack: + # msa_transition: 128 # 5120 + # msa_row_attention_with_pair_bias: 512 # 5120 + # msa_column_global_attention: 128 # seq len + # outer_product_mean: 128 # seq len + # triangle_attention_starting_node: 128 # seq len + # triangle_attention_ending_node: 128 # seq len + # pair_transition: 128 # seq len + # msa_stack: + # msa_transition: 128 + # msa_row_attention_with_pair_bias: 128 + # msa_column_attention: 128 + # outer_product_mean: 128 + # triangle_attention_starting_node: 128 + # triangle_attention_ending_node: 128 + # pair_transition: 128 + # seq_8192: + # template_embedding: 32 # seq len * seq len + # template_pair_stack: + # triangle_attention_starting_node: 128 # seq len + # triangle_attention_ending_node: 128 # seq len + # pair_transition: 16 # seq len + # extra_msa_stack: + # msa_transition: 16 # 5120 + # msa_row_attention_with_pair_bias: 64 # 5120 + # msa_column_global_attention: 64 # seq len + # outer_product_mean: 32 # seq len + # triangle_attention_starting_node: 128 # seq len + # triangle_attention_ending_node: 128 # seq len + # pair_transition: 16 # seq len + # msa_stack: + # msa_transition: 64 + # msa_row_attention_with_pair_bias: 64 + # msa_column_attention: 64 + # outer_product_mean: 64 + # triangle_attention_starting_node: 128 + # triangle_attention_ending_node: 128 + # pair_transition: 64 + # seq_6144: + # template_embedding: 256 # seq len * seq len + # template_pair_stack: + # triangle_attention_starting_node: 512 # seq len + # triangle_attention_ending_node: 512 # seq len + # pair_transition: 256 # seq len + # extra_msa_stack: + # msa_transition: 256 # 5120 + # msa_row_attention_with_pair_bias: 512 # 5120 + # msa_column_global_attention: 512 # seq len + # outer_product_mean: 256 # seq len + # triangle_attention_starting_node: 512 # seq len + # triangle_attention_ending_node: 512 # seq len + # pair_transition: 256 # seq len + # msa_stack: + # msa_transition: 256 + # msa_row_attention_with_pair_bias: 512 + # msa_column_attention: 512 + # outer_product_mean: 512 + # triangle_attention_starting_node: 512 + # triangle_attention_ending_node: 512 + # pair_transition: 256 + # seq_6144: + # template_embedding: 128 # seq len * seq len + # template_pair_stack: + # triangle_attention_starting_node: 128 # seq len + # triangle_attention_ending_node: 128 # seq len + # pair_transition: 128 # seq len + # extra_msa_stack: + # msa_transition: 128 # 5120 + # msa_row_attention_with_pair_bias: 512 # 5120 + # msa_column_global_attention: 128 # seq len + # outer_product_mean: 128 # seq len + # triangle_attention_starting_node: 128 # seq len + # triangle_attention_ending_node: 128 # seq len + # pair_transition: 128 # seq len + # msa_stack: + # msa_transition: 128 + # msa_row_attention_with_pair_bias: 128 + # 
msa_column_attention: 128 + # outer_product_mean: 128 + # triangle_attention_starting_node: 128 + # triangle_attention_ending_node: 128 + # pair_transition: 128 +heads: + resolution: 1 + predicted_lddt: + filter_by_resolution: True + max_resolution: 3.0 + min_resolution: 0.1 + num_bins: 50 + num_channels: 128 + weight: 0.01 + distogram: + first_break: 2.3125 + last_break: 21.6875 + num_bins: 64 + weight: 0.3 + masked_msa: + num_output: 22 + weight: 2.0 + predicted_aligned_error: + max_error_bin: 31.0 + num_bins: 64 + num_channels: 128 + filter_by_resolution: True + min_resolution: 0.1 + max_resolution: 3.0 + weight: 0.0 + experimentally_resolved: + filter_by_resolution: True + max_resolution: 3.0 + min_resolution: 0.1 + weight: 0.01 + structure_module: + fape: + clamp_distance: 10.0 + loss_unit_distance: 10.0 + angle_norm_weight: 0.01 + chi_weight: 0.5 + clash_overlap_tolerance: 1.5 + sidechain: + atom_clamp_distance: 10.0 + weight_frac: 0.5 + length_scale: 10.0 + structural_violation_loss_weight: 1.0 + violation_tolerance_factor: 12.0 +multimer: + embeddings_and_evoformer: + num_msa: 508 + num_extra_msa: 2048 + masked_msa: + profile_prob: 0.1 + replace_fraction: 0.15 + same_prob: 0.1 + uniform_prob: 0.1 + use_chain_relative: True + max_relative_chain: 2 + pair_in_dim: 73 \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/config/multimer-data.yaml b/MindSPONGE/applications/research/Grasp/config/multimer-data.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3ca2af115709a464d56ecd771ed88a5c018ddeb --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/config/multimer-data.yaml @@ -0,0 +1,77 @@ +common: + crop_size: 256 + max_msa_entry: 33554432 # 1 << 25 + max_msa_clusters: 256 + max_extra_msa: 1024 + max_templates: 4 + num_ensembles: 1 + num_recycle: 3 + profile_prob: 0.1 + same_prob: 0.1 + uniform_prob: 0.1 + replace_fraction: 0.15 + replace_proportion: 0.0 + spatial_crop_prob: 0.5 + ca_ca_threshold: 10.0 + biased_msa_by_chain: True + distillation: False + use_templates: True + use_masked_msa: True + share_mask: True + msa_cluster_features: True + subsample_templates: True + use_template_torsion_angles: True + reduce_msa_clusters_by_max_templates: True + template_features: + - template_all_atom_positions + - template_sum_probs + - template_aatype + - template_all_atom_mask + unsupervised_features: + - aatype + - residue_index + - msa + - msa_chains + - num_alignments + - seq_length + - between_segment_residues + - deletion_matrix + - crop_and_fix_size_seed + recycling_features: + - msa_chains + - msa_mask + - msa_row_mask + - bert_mask + - true_msa + - msa_feat + - extra_msa_deletion_value + - extra_msa_has_deletion + - extra_msa + - extra_msa_mask + - extra_msa_row_mask + - is_distillation + multimer_features: + - assembly_num_chains + - asym_id + - sym_id + - num_sym + - entity_id + - asym_len + - cluster_bias_mask + supervised_features: + - all_atom_mask + - all_atom_positions + - resolution + - use_clamped_fape + - is_distillation + + +eval: + crop_size: 256 + fixed_size: True + masked_msa_replace_fraction: 0.15 + max_msa_clusters: 512 + max_templates: 4 + num_ensemble: 1 + subsample_templates: True + keep_extra: True \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/config/multimer-model.yaml b/MindSPONGE/applications/research/Grasp/config/multimer-model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f33573b07c6169252d6e8ad9118433b23be0dfeb --- /dev/null +++ 
b/MindSPONGE/applications/research/Grasp/config/multimer-model.yaml @@ -0,0 +1,464 @@ +is_training: False +msa_channel: 256 +pair_channel: 128 +extra_msa_channel: 64 +max_relative_feature: 32 +recycle_features: True +recycle_pos: True +seq_channel: 384 +GPU: + lr_max: 0.0003 #1e-3 + lr_min: 0.0001 #1e-4 + warmup_steps: 1000 + start_step: 0 + lr_decay_steps: 75000 +prev_pos: + min_bin: 3.25 + max_bin: 20.75 + num_bins: 15 +common: + target_feat_dim: 21 + msa_feat_dim: 49 + dgram_dim: 15 + pair_in_dim: 65 + msa_first_row_dim: 256 + prev_pair_dim: 128 + extra_msa_dim: 25 + template_feat_dim: 57 +template: + enabled: True + embed_torsion_angles: True + use_template_unit_vector: True + attention: + gating: False + key_dim: 64 + num_head: 4 + value_dim: 64 + dgram_features: + min_bin: 3.25 + max_bin: 50.75 + num_bins: 39 + template_pair_stack: + num_block: 2 + triangle_attention_starting_node: + dropout_rate: 0.25 + gating: True + key_dim: 64 + num_head: 4 + orientation: 'per_row' + shared_dropout: True + value_dim: 64 + triangle_attention_ending_node: + dropout_rate: 0.25 + gating: True + key_dim: 64 + num_head: 4 + orientation: 'per_column' + shared_dropout: True + value_dim: 64 + triangle_multiplication_outgoing: + dropout_rate: 0.25 + equation: 'ikc,jkc->ijc' + num_intermediate_channel: 64 + orientation: 'per_row' + shared_dropout: True + triangle_multiplication_incoming: + dropout_rate: 0.25 + equation: 'kjc,kic->ijc' + num_intermediate_channel: 64 + orientation: 'per_row' + shared_dropout: True + pair_transition: + dropout_rate: 0.0 + num_intermediate_factor: 2 + orientation: 'per_row' + shared_dropout: True +evoformer: + msa_stack_num: 48 + extra_msa_stack_num: 4 + msa_row_attention_with_pair_bias: + dropout_rate: 0.15 # 0.15 + gating: True + num_head: 8 + orientation: 'per_row' + shared_dropout: True + msa_column_attention: + dropout_rate: 0.0 + gating: True + num_head: 8 + orientation: 'per_column' + shared_dropout: True + msa_transition: + dropout_rate: 0.0 + num_intermediate_factor: 4 + orientation: 'per_row' + shared_dropout: True + outer_product_mean: + chunk_size: 128 + dropout_rate: 0.0 + num_outer_channel: 32 + orientation: 'per_row' + shared_dropout: True + triangle_attention_starting_node: + dropout_rate: 0.25 # 0.25 + gating: True + num_head: 4 + orientation: 'per_row' + shared_dropout: True + triangle_attention_ending_node: + dropout_rate: 0.25 # 0.25 + gating: True + num_head: 4 + orientation: 'per_column' + shared_dropout: True + triangle_multiplication_outgoing: + dropout_rate: 0.25 # 0.25 + equation: 'ikc,jkc->ijc' + num_intermediate_channel: 128 + orientation: 'per_row' + shared_dropout: True + triangle_multiplication_incoming: + dropout_rate: 0.25 # 0.25 + equation: 'kjc,kic->ijc' + num_intermediate_channel: 128 + orientation: 'per_row' + shared_dropout: True + pair_transition: + dropout_rate: 0.0 + num_intermediate_factor: 4 + orientation: 'per_row' + shared_dropout: True +structure_module: + num_layer: 8 + fape: + clamp_distance: 10.0 + clamp_type: 'relu' + loss_unit_distance: 10.0 + angle_norm_weight: 0.01 + chi_weight: 0.5 + clash_overlap_tolerance: 1.5 + compute_in_graph_metrics: True + dropout: 0.1 + num_channel: 384 + num_head: 12 + num_layer_in_transition: 3 + num_point_qk: 4 + num_point_v: 8 + num_scalar_qk: 16 + num_scalar_v: 16 + position_scale: 20.0 + sidechain: + atom_clamp_distance: 10.0 + num_channel: 128 + num_residual_block: 2 + weight_frac: 0.5 + length_scale: 10. 
+ structural_violation_loss_weight: 1.0 + violation_tolerance_factor: 12.0 + weight: 1.0 +slice: + seq_256: + template_embedding: 0 + template_pair_stack: + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + extra_msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 4 + msa_column_global_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 0 + msa_column_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + seq_384: + template_embedding: 0 + template_pair_stack: + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + extra_msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 128 + msa_column_global_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 0 + msa_column_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + seq_512: + template_embedding: 0 + template_pair_stack: + triangle_attention_starting_node: 4 + triangle_attention_ending_node: 4 + pair_transition: 0 + extra_msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 64 + msa_column_global_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 0 + msa_column_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + seq_768: + template_embedding: 8 + template_pair_stack: + triangle_attention_starting_node: 8 + triangle_attention_ending_node: 8 + pair_transition: 8 + extra_msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 128 + msa_column_global_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 0 + msa_column_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + seq_1024: + template_embedding: 8 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 8 # seq len + triangle_attention_ending_node: 8 # seq len + pair_transition: 8 # seq len + extra_msa_stack: + msa_transition: 0 # 5120 + msa_row_attention_with_pair_bias: 128 # 5120 + msa_column_global_attention: 8 # seq len + outer_product_mean: 0 # seq len + triangle_attention_starting_node: 8 # seq len + triangle_attention_ending_node: 8 # seq len + pair_transition: 0 # seq len + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 8 + msa_column_attention: 8 + outer_product_mean: 0 + triangle_attention_starting_node: 8 + triangle_attention_ending_node: 8 + pair_transition: 0 + seq_1280: + template_embedding: 8 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 32 # seq len + triangle_attention_ending_node: 32 # seq len + pair_transition: 8 # seq len + extra_msa_stack: + msa_transition: 0 # 5120 + msa_row_attention_with_pair_bias: 128 # 5120 + msa_column_global_attention: 8 # seq len + 
outer_product_mean: 0 # seq len + triangle_attention_starting_node: 8 # seq len + triangle_attention_ending_node: 8 # seq len + pair_transition: 0 # seq len + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 8 + msa_column_attention: 8 + outer_product_mean: 0 + triangle_attention_starting_node: 8 + triangle_attention_ending_node: 8 + pair_transition: 0 + seq_1536: + template_embedding: 16 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 32 # seq len + triangle_attention_ending_node: 32 # seq len + pair_transition: 8 # seq len + extra_msa_stack: + msa_transition: 8 # 5120 + msa_row_attention_with_pair_bias: 256 # 5120 + msa_column_global_attention: 32 # seq len + outer_product_mean: 8 # seq len + triangle_attention_starting_node: 32 # seq len + triangle_attention_ending_node: 32 # seq len + pair_transition: 8 # seq len + msa_stack: + msa_transition: 8 + msa_row_attention_with_pair_bias: 32 + msa_column_attention: 32 + outer_product_mean: 8 + triangle_attention_starting_node: 32 + triangle_attention_ending_node: 32 + pair_transition: 8 + seq_1792: + template_embedding: 64 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 64 # seq len + triangle_attention_ending_node: 64 # seq len + pair_transition: 8 # seq len + extra_msa_stack: + msa_transition: 8 # 5120 + msa_row_attention_with_pair_bias: 512 # 5120 + msa_column_global_attention: 64 # seq len + outer_product_mean: 8 # seq len + triangle_attention_starting_node: 64 # seq len + triangle_attention_ending_node: 64 # seq len + pair_transition: 8 # seq len + msa_stack: + msa_transition: 8 + msa_row_attention_with_pair_bias: 64 + msa_column_attention: 64 + outer_product_mean: 8 + triangle_attention_starting_node: 64 + triangle_attention_ending_node: 64 + pair_transition: 8 + seq_2048: + template_embedding: 128 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 128 # seq len + extra_msa_stack: + msa_transition: 128 # 5120 + msa_row_attention_with_pair_bias: 512 # 5120 + msa_column_global_attention: 128 # seq len + outer_product_mean: 128 # seq len + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 128 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 128 + msa_column_attention: 128 + outer_product_mean: 128 + triangle_attention_starting_node: 128 + triangle_attention_ending_node: 128 + pair_transition: 128 + seq_2304: + template_embedding: 128 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 128 # seq len + extra_msa_stack: + msa_transition: 128 # 5120 + msa_row_attention_with_pair_bias: 512 # 5120 + msa_column_global_attention: 256 # seq len + outer_product_mean: 128 # seq len + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 128 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 256 + msa_column_attention: 256 + outer_product_mean: 256 + triangle_attention_starting_node: 256 + triangle_attention_ending_node: 256 + pair_transition: 128 + seq_3072: + template_embedding: 128 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 256 # seq len + triangle_attention_ending_node: 256 # seq len + pair_transition: 128 # seq len + extra_msa_stack: + 
msa_transition: 0 # 5120 + msa_row_attention_with_pair_bias: 128 # 5120 + msa_column_global_attention: 8 # seq len + outer_product_mean: 0 # seq len + triangle_attention_starting_node: 8 # seq len + triangle_attention_ending_node: 8 # seq len + pair_transition: 0 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 256 + msa_column_attention: 256 + outer_product_mean: 256 + triangle_attention_starting_node: 256 + triangle_attention_ending_node: 256 + pair_transition: 128 + seq_4096: + template_embedding: 128 # seq len * seq len + template_pair_stack: + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 128 # seq len + extra_msa_stack: + msa_transition: 128 # 5120 + msa_row_attention_with_pair_bias: 512 # 5120 + msa_column_global_attention: 128 # seq len + outer_product_mean: 128 # seq len + triangle_attention_starting_node: 128 # seq len + triangle_attention_ending_node: 128 # seq len + pair_transition: 128 # seq len + msa_stack: + msa_transition: 128 + msa_row_attention_with_pair_bias: 128 + msa_column_attention: 128 + outer_product_mean: 128 + triangle_attention_starting_node: 128 + triangle_attention_ending_node: 128 + pair_transition: 128 +heads: + resolution: 1 + predicted_lddt: + filter_by_resolution: True + max_resolution: 3.0 + min_resolution: 0.1 + num_bins: 50 + num_channels: 128 + weight: 0.01 + distogram: + first_break: 2.3125 + last_break: 21.6875 + num_bins: 64 + weight: 0.3 + masked_msa: + num_output: 22 + weight: 2.0 + predicted_aligned_error: + max_error_bin: 31.0 + num_bins: 64 + num_channels: 128 + filter_by_resolution: True + min_resolution: 0.1 + max_resolution: 3.0 + weight: 0.0 + experimentally_resolved: + filter_by_resolution: True + max_resolution: 3.0 + min_resolution: 0.1 + weight: 0.01 +multimer: + embeddings_and_evoformer: + num_msa: 252 + masked_msa: + profile_prob: 0.1 + replace_fraction: 0.15 + same_prob: 0.1 + uniform_prob: 0.1 + use_chain_relative: True + max_relative_chain: 2 + pair_in_dim: 73 \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/data/__init__.py b/MindSPONGE/applications/research/Grasp/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..41010cdb504f4d55f4356d75af0aaa82850729e1 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/data/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +'''init''' +from .preprocess import Feature, MultimerFeature +# from .protein_feature import RawFeatureGenerator +from .utils import get_crop_size, get_raw_feature +from .dataset import create_dataset, process_pdb, OUTPUT_LABEL_KEYS + diff --git a/MindSPONGE/applications/research/Grasp/data/dataset.py b/MindSPONGE/applications/research/Grasp/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ceb1c3cf487fde07c99beb0c16e5b06673314a3b --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/data/dataset.py @@ -0,0 +1,389 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train dataset""" +import datetime +import random +import os +import pickle +import time +import numpy as np +from mindspore import dataset as ds +from mindspore.communication import get_rank + +from mindsponge1.common.residue_constants import make_atom14_dists_bounds, order_restype_with_x +from mindsponge1.common.protein import from_pdb_string +from mindsponge1.common.utils import make_atom14_positions, get_aligned_seq +from mindsponge1.data.data_transform import pseudo_beta_fn, atom37_to_frames, atom37_to_torsion_angles +from .preprocess import Feature +from .multimer_pipeline import add_assembly_features, pair_and_merge, post_process +from .multimer_process import process_labels + + +OUTPUT_LABEL_KEYS = ['aatype_per_chain', 'all_atom_positions', 'all_atom_mask', 'atom14_atom_exists', + 'atom14_gt_exists', 'atom14_gt_positions', 'residx_atom14_to_atom37', + 'atom37_atom_exists_per_chain', 'atom14_alt_gt_positions', 'atom14_alt_gt_exists', + 'atom14_atom_is_ambiguous', 'rigidgroups_gt_frames', 'rigidgroups_gt_exists', + 'rigidgroups_alt_gt_frames', 'backbone_affine_tensor', 'torsion_angles_sin_cos', + 'pseudo_beta', 'pseudo_beta_mask', 'chi_mask', 'backbone_affine_mask', + 'chain_index'] +def create_dataset(pdb_path, pkl_path, paired_pkl_path, all_name_list, data_cfg, resolution_data, shuffle=False, + num_parallel_worker=4, hard_rate=0, high=25, + is_parallel=False, mixed_precision=False): + """create train dataset""" + + column_name = ['aatype', 'residue_index', 'template_aatype', 'template_all_atom_masks', + 'template_all_atom_positions', 'asym_id', 'sym_id', 'entity_id', 'seq_mask', 'msa_mask', + 'target_feat', 'msa_feat', 'extra_msa', 'extra_msa_deletion_value', 'extra_msa_mask', + 'residx_atom37_to_atom14', 'atom37_atom_exists', + "prev_pos", "prev_msa_first_row", "prev_pair", + "num_sym", "bert_mask", "true_msa", ] + \ + OUTPUT_LABEL_KEYS + \ + ["atomtype_radius", "restype_atom14_bond_lower_bound", "restype_atom14_bond_upper_bound", \ + "use_clamped_fape", "filter_by_solution", "prot_name_index"] + + dataset_generator = DatasetGenerator(pdb_path, pkl_path, paired_pkl_path, all_name_list, data_cfg, resolution_data, mixed_precision, hard_rate, high) + 
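+    # Note: `column_name` must list exactly as many names, in the same order, as
+    # the tuple returned by DatasetGenerator.__getitem__ (input features, the
+    # recycling tensors, OUTPUT_LABEL_KEYS, then the constant extra features);
+    # MindSpore's GeneratorDataset maps tuple elements to columns positionally.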
prefetch_size = 1
+    print("prefetch_size", prefetch_size)
+    ds.config.set_prefetch_size(prefetch_size)
+
+    if is_parallel:
+        rank_id = get_rank() % 8
+        rank_size = 8
+        train_dataset = ds.GeneratorDataset(source=dataset_generator, column_names=column_name,
+                                            num_parallel_workers=num_parallel_worker, shuffle=shuffle,
+                                            num_shards=rank_size,
+                                            shard_id=rank_id, max_rowsize=16)
+    else:
+        train_dataset = ds.GeneratorDataset(source=dataset_generator, column_names=column_name,
+                                            num_parallel_workers=num_parallel_worker, shuffle=shuffle,
+                                            max_rowsize=16)
+    return train_dataset
+
+
+class DatasetGenerator:
+    """dataset generator"""
+    def __init__(self, pdb_path, pkl_path, paired_pkl_path, all_name_list, data_cfg, resolution_data,
+                 mixed_precision, hard_rate, high=25):
+        self.t1 = time.time()
+        self.pdb_path = pdb_path
+        self.pkl_path = pkl_path
+        self.paired_pkl_path = paired_pkl_path
+        self.all_name_list = all_name_list
+        self.data_cfg = data_cfg
+        self.resolution_info = resolution_data
+        self.mixed_precision = mixed_precision
+        self.hard_rate = hard_rate
+        self.high = high
+        print("end dataset init")
+
+    def _random_sample_chains(self, name_list, max_chains=32):
+        np.random.shuffle(name_list)
+        return name_list[:max_chains]
+
+    def __getitem__(self, index):
+        is_multimer = True
+        name_list = self.all_name_list[index]
+        try:
+            name_list = self._random_sample_chains(name_list)
+            input_arrays, prev_pos, prev_msa_first_row, prev_pair, \
+            num_sym, bert_mask, true_msa, labels_arrays \
+                = self._get_train_data(name_list, is_multimer)
+        except Exception as err:  # a bare `except:` would also swallow KeyboardInterrupt
+            print('error for name', self.all_name_list[index], ':', err)
+            name_list = self._random_sample_chains(self.all_name_list[0])
+            input_arrays, prev_pos, prev_msa_first_row, prev_pair, \
+            num_sym, bert_mask, true_msa, labels_arrays \
+                = self._get_train_data(name_list, is_multimer)
+
+        prot_name_index = np.array([index]).astype(np.int32)
+        atomtype_radius = np.array(
+            [1.55, 1.7, 1.7, 1.7, 1.52, 1.7, 1.7, 1.7, 1.52, 1.52, 1.8, 1.7, 1.7, 1.7, 1.55, 1.55,
+             1.52, 1.52, 1.8, 1.7, 1.7, 1.7, 1.7, 1.55, 1.55, 1.55, 1.52, 1.52, 1.7, 1.55, 1.55,
+             1.52, 1.7, 1.7, 1.7, 1.55, 1.52])
+        restype_atom14_bond_lower_bound, restype_atom14_bond_upper_bound, _ = \
+            make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=12.0)
+        use_clamped_fape = np.random.binomial(1, 0.9, size=1)
+        filter_by_solution = self._get_solution_flag(name_list[0].split("_")[0])
+        extra_feats = [atomtype_radius, restype_atom14_bond_lower_bound,
+                       restype_atom14_bond_upper_bound, use_clamped_fape, filter_by_solution]
+
+        dtype = np.float32
+        if self.mixed_precision:
+            dtype = np.float16
+        extra_feats = [array.astype(dtype) for array in extra_feats] + [prot_name_index]
+
+        all_feats = input_arrays + [prev_pos, prev_msa_first_row, prev_pair, num_sym, bert_mask, true_msa] \
+            + labels_arrays + extra_feats
+        return tuple(all_feats)
+
+    def __len__(self):
+        return len(self.all_name_list)
+
+    def _get_solution_flag(self, prot_name):
+        """get resolution data"""
+        if prot_name not in self.resolution_info:
+            return np.array(1.0).astype(np.float32)
+        resolution = float(self.resolution_info[prot_name])
+        if resolution < 3:
+            return np.array(1.0).astype(np.float32)
+        return np.array(0.0).astype(np.float32)
+
+    def _get_random_sampled_index(self, total_num, high=25):
+        need_num = min(np.random.randint(1, high + 1), total_num)
+        sampled_index = random.sample(range(total_num), need_num)
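+        # e.g. with total_num=100 and high=25 this keeps between 1 and 25 randomly
+        # chosen MSA rows; the "hard" (shallow-MSA) mode triggered by `hard_rate`
+        # uses this to thin both the paired and the per-chain alignments.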
return sampled_index + + + + + def _get_train_data(self, name_list, is_multimer=True): + """get train data""" + + def load_multi_data(name_list): + + prot_name = name_list[0].split("_")[0] + turn_hard = np.random.rand() < self.hard_rate + + paired_feature = None + if len(name_list) > 1 and os.path.exists(f"{self.paired_pkl_path}/{prot_name}.pkl"): + with open(f"{self.paired_pkl_path}/{prot_name}.pkl", "rb") as f: + paired_feature = pickle.load(f) + if turn_hard and len(paired_feature) > 0: + sampled_index = self._get_random_sampled_index(list(paired_feature.values())[0]['msa'].shape[0], self.high) + for k, v in paired_feature.items(): + for k1, v1 in v.items(): + if k1 in ['msa', 'deletion_matrix']: + paired_feature[k][k1] = v1[sampled_index] + + + all_seq_len = 0 + features_all = [] + sequences = [] + turn_hard_seq_index = {} + for name in name_list: + + features = {} + pkl_path_single = os.path.join(self.pkl_path, name + ".pkl") + + with open(pkl_path_single, "rb") as f: + raw_feature = pickle.load(f) + features['aatype']=np.nonzero(raw_feature['aatype'])[1].astype(np.int32) + seq_len = raw_feature["msa"].shape[1] + features["between_segment_residues"] = raw_feature["between_segment_residues"] + features["residue_index"] = raw_feature["residue_index"] + seq = raw_feature["sequence"][0].decode() + features["sequence"] = np.array(seq) + sequences.append(seq) + + features["msa"] = raw_feature["msa"] + features["deletion_matrix"] = raw_feature["deletion_matrix_int"] + if turn_hard: + if seq not in turn_hard_seq_index: + sampled_index = self._get_random_sampled_index(features["msa"].shape[0], self.high) + turn_hard_seq_index[seq] = sampled_index + else: + sampled_index = turn_hard_seq_index[seq] + features["msa"] = features["msa"][sampled_index] + features["deletion_matrix"] = features["deletion_matrix"][sampled_index] + features["num_alignments"] = np.array(features["msa"].shape[0]) + + if (not turn_hard) and (len(raw_feature["template_aatype"].shape) > 1): + features["template_aatype"] = np.argmax(raw_feature["template_aatype"], axis=-1) + features["template_all_atom_mask"] = raw_feature["template_all_atom_masks"] + features["template_all_atom_positions"] = raw_feature["template_all_atom_positions"] + else: + features["template_aatype"] = np.zeros((1, seq_len)).astype(np.int32) + features["template_all_atom_mask"] = np.zeros((1, seq_len, 37)).astype(np.float32) + features["template_all_atom_positions"] = np.zeros((1, seq_len, 37, 3)).astype(np.float32) + + + if paired_feature: + features["msa_all_seq"] = paired_feature[seq]["msa"] + features["deletion_matrix_all_seq"] = paired_feature[seq]["deletion_matrix"] + features["num_alignments_all_seq"] = np.array(features["msa_all_seq"].shape[0]) + all_seq_len += seq_len + + pdb_path_single = os.path.join(self.pdb_path, name + ".pdb") + with open(pdb_path_single, 'r') as f: + prot_pdb = from_pdb_string(f.read()) + aatype = prot_pdb.aatype + seq_len = len(aatype) + atom37_positions = prot_pdb.atom_positions.astype(np.float32) + atom37_mask = prot_pdb.atom_mask.astype(np.float32) + + features["seq_length"] = np.array(seq_len) + features["aatype_pdb"] = np.array(aatype) + features["all_atom_positions"] = atom37_positions + features["all_atom_mask"] = atom37_mask + + features_all.append(features) + + is_homomer = len(set(sequences)) == 1 and len(sequences) > 1 + # is_homomer = len(set(sequences)) == 1 + + if is_homomer and "msa_all_seq" not in features_all[0].keys(): + for features in features_all: + features["msa_all_seq"] = features["msa"] + 
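+                # For a homomer there is nothing to pair across chains: every
+                # chain has the same sequence, so its own MSA is reused verbatim
+                # as the paired (`_all_seq`) features here.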
features["deletion_matrix_all_seq"] = features["deletion_matrix"] + features["num_alignments_all_seq"] = np.array(features["msa_all_seq"].shape[0]) + + # print(f"\n\n\n=========================={name_list}") + # for i, features in enumerate(features_all): + # print(f"\n=========================={i}") + # for key, value in features.items(): + # print(key, value.shape, value.dtype) + + # print(len(name_list), prot_name, all_seq_len) + return features_all, all_seq_len + + + features, all_seq_len = load_multi_data(name_list) + + # if "msa_all_seq" not in feature and\ + # np.sum([feature["msa"].shape[0]==features[0]["msa"].shape[0] for feature in features]) < len(features): + # print(f"paired msa num not the same for prot ", name_list[0].split("_")[0]) + # paired_msa_num = np.min([feature["msa_all_seq"].shape[0] for feature in features]) + # for feature in features: + # feature["msa_all_seq"] = feature["msa_all_seq"][:1] + # feature["deletion_matrix_all_seq"] = feature["deletion_matrix_all_seq"][:1] + # feature["num_alignments_all_seq"] = np.array(1) + + + features = add_assembly_features(features) + # for i, feature in enumerate(features): + # print("\n\n", i) + # for key, value in feature.items(): + # print(key, value.shape, value.dtype) + + all_labels = [{k: f[k].copy() for k in ["aatype_pdb", "all_atom_positions", "all_atom_mask"]} for f in features] + + asym_len = np.array([c["seq_length"] for c in features], dtype=np.int64) + + features = pair_and_merge(features) + features = post_process(features) + features["asym_len"] = asym_len + processed_feature = Feature(self.data_cfg, features, is_training=True, is_multimer=True) + + seed = global_seed() + input_arrays, prev_pos, prev_msa_first_row, prev_pair, num_sym, bert_mask, true_msa \ + = processed_feature.pipeline(self.data_cfg, self.mixed_precision, seed=seed) + + + all_labels = process_labels(all_labels) + # print(f"\n\n==========================all_labels") + # for key, value in all_labels[0].items(): + # print(key, value.shape, value.dtype, flush=True) + # keys = list(all_labels[0].keys()) + # print(keys) + # keys.sort() + # for i, all_label in enumerate(all_labels): + # print("\n\n\n===============", i) + # for key in OUTPUT_LABEL_KEYS: + # value = all_label[key] + # print(key, value.shape, value.dtype, flush=True) + + def merge_label_dicts(all_labels): + labels_arrays = [] + for key in OUTPUT_LABEL_KEYS: + values = [] + for all_label in all_labels: + values.append(all_label[key]) + value = np.concatenate(values, axis=0) + if value.dtype == "float64": + value = value.astype(np.float16) + if value.dtype == "float32": + value = value.astype(np.float16) + if value.dtype == "int64": + value = value.astype(np.int32) + labels_arrays.append(value) + return labels_arrays + + labels_arrays = merge_label_dicts(all_labels) + # for array in labels_arrays: + # print(array.shape, array.dtype) + + return input_arrays, prev_pos, prev_msa_first_row, prev_pair, num_sym, bert_mask, true_msa, labels_arrays + + +class SeedMaker: + """Return unique seeds.""" + + def __init__(self, initial_seed=0): + self.next_seed = initial_seed + + def __call__(self): + i = self.next_seed + self.next_seed += 1 + return i + + +global_seed = SeedMaker() + + +def process_pdb(true_aatype, ori_res_length, decoy_pdb_path): + """get atom information from pdb""" + with open(decoy_pdb_path, 'r') as f: + decoy_prot_pdb = from_pdb_string(f.read()) + f.close() + decoy_aatype = decoy_prot_pdb.aatype + decoy_atom37_positions = decoy_prot_pdb.atom_positions.astype(np.float32) + 
decoy_atom37_mask = decoy_prot_pdb.atom_mask.astype(np.float32) + padding_val = true_aatype.shape[0] - ori_res_length + true_aatype = true_aatype[:ori_res_length] + decoy_aatype, decoy_atom37_positions, decoy_atom37_mask, align_mask = \ + align_with_aatype(true_aatype, decoy_aatype, decoy_atom37_positions, decoy_atom37_mask) + decoy_atom37_positions = np.pad(decoy_atom37_positions, ((0, padding_val), (0, 0), (0, 0))) + decoy_atom37_mask = np.pad(decoy_atom37_mask, ((0, padding_val), (0, 0))) + align_mask = np.pad(align_mask, ((0, padding_val))) + + return decoy_atom37_positions, decoy_atom37_mask, align_mask + + +def align_with_aatype(true_aatype, aatype, atom37_positions, atom37_mask): + """align pdb with aatype""" + if len(true_aatype) == len(aatype): + out = aatype, atom37_positions, atom37_mask, np.ones((aatype.shape[0])).astype(np.float32) + return out + seq1 = [order_restype_with_x.get(x) for x in aatype] + seq2 = [order_restype_with_x.get(x) for x in true_aatype] + seq1 = ''.join(seq1) + seq2 = ''.join(seq2) + _, align_relationship, _ = get_aligned_seq(seq1, seq2) + pdb_index = 0 + seq_len = len(true_aatype) + new_aatype = np.zeros((seq_len,)).astype(np.int32) + new_atom37_positions = np.zeros((seq_len, 37, 3)).astype(np.float32) + new_atom37_mask = np.zeros((seq_len, 37)).astype(np.float32) + align_mask = np.zeros((seq_len,)).astype(np.float32) + for i in range(len(true_aatype)): + if align_relationship[i] == "-": + new_aatype[i] = 20 + new_atom37_positions[i] = np.zeros((37, 3)).astype(np.float32) + new_atom37_mask[i] = np.zeros((37,)).astype(np.float32) + align_mask[i] = 0 + else: + new_aatype[i] = aatype[pdb_index] + new_atom37_positions[i] = atom37_positions[pdb_index] + new_atom37_mask[i] = atom37_mask[pdb_index] + align_mask[i] = 1 + pdb_index += 1 + out = new_aatype, new_atom37_positions, new_atom37_mask, align_mask + return out diff --git a/MindSPONGE/applications/research/Grasp/data/multimer_pipeline.py b/MindSPONGE/applications/research/Grasp/data/multimer_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..5a11de2a715ebec8d80c2c53b04bb58b7f6e4bd0 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/data/multimer_pipeline.py @@ -0,0 +1,715 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""multimer data preprocess pipeline"""
+
+import collections
+import numpy as np
+import pandas as pd
+import scipy.linalg
+
+from mindsponge1.common import residue_constants
+from mindsponge1.data.data_transform import process_unmerged_features, get_crop_size, correct_msa_restypes, \
+    make_seq_mask, make_msa_mask, add_padding
+
+REQUIRED_FEATURES = frozenset({
+    'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
+    'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
+    'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
+    'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
+    'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
+    'num_templates', 'queue_size', 'residue_index', 'resolution',
+    'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
+    'template_all_atom_mask', 'template_all_atom_positions',
+    "asym_len", "template_sum_probs", "num_sym", "msa_chains"  # dyh
+})
+MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
+TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
+                     'template_all_atom_mask')
+SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
+                'all_atom_mask', 'seq_mask', 'between_segment_residues',
+                'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
+                'sym_id', 'entity_mask', 'deletion_mean',
+                'prediction_atom_mask',
+                'literature_positions', 'atom_indices_to_group_indices',
+                'rigid_group_default_frame', "num_sym")  # dyh
+CHAIN_FEATURES = ('num_alignments', 'seq_length')
+MAX_TEMPLATES = 4
+MSA_CROP_SIZE = 2048
+
+
+def int_id_to_str_id(num: int) -> str:
+    """Encodes a number as a string, using reverse spreadsheet style naming.
+
+    Args:
+        num: A positive integer.
+
+    Returns:
+        A string that encodes the positive integer using reverse spreadsheet
+        style naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ...
+        This is the usual way to encode chain IDs in mmCIF files.
+    """
+    if num <= 0:
+        raise ValueError(f'Only positive integers allowed, got {num}.')
+
+    num = num - 1  # 1-based indexing.
+    output = []
+    while num >= 0:
+        output.append(chr(num % 26 + ord('A')))
+        num = num // 26 - 1
+    return ''.join(output)
+
+
+def add_assembly_features(all_chain_features):
+    """Add features to distinguish between chains.
+
+    Args:
+        all_chain_features: A list of feature dictionaries, one per chain.
+
+    Returns:
+        The same chain features as a list, with per-residue `asym_id`, `sym_id`,
+        `entity_id` and `num_sym` features added. E.g. the two chains of a
+        homodimer share one entity_id but receive sym_id 1 and 2, while the two
+        chains of a heterodimer receive distinct entity_ids.
+ """ + # Group the chains by sequence + seq_to_entity_id = {} + grouped_chains = {} + # for chain_id, chain_features in all_chain_features.items(): + for chain_features in all_chain_features: + seq = str(chain_features['sequence']) + if seq not in seq_to_entity_id: + seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 + entity_id_x = seq_to_entity_id.get(seq) + if entity_id_x not in grouped_chains: + grouped_chains[entity_id_x] = [] + grouped_chains.get(entity_id_x).append(chain_features) + + new_all_chain_features = [] + chain_id = 1 + for entity_id, group_chain_features in grouped_chains.items(): + num_sym = len(group_chain_features) # dyh + for sym_id, chain_features in enumerate(group_chain_features, start=1): + # new_all_chain_features[ + # f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features + seq_length = chain_features['seq_length'] + chain_features['asym_id'] = chain_id * np.ones(seq_length).astype(np.int32) + chain_features['sym_id'] = sym_id * np.ones(seq_length).astype(np.int32) + chain_features['entity_id'] = entity_id * np.ones(seq_length).astype(np.int32) + chain_features["num_sym"] = num_sym * np.ones(seq_length).astype(np.int32) # dyh + chain_id += 1 + new_all_chain_features.append(chain_features) + + return new_all_chain_features + + +def _is_homomer_or_monomer(chains) -> bool: + """Checks if a list of chains represents a homomer/monomer example.""" + # Note that an entity_id of 0 indicates padding. + num_unique_chains = len(np.unique(np.concatenate( + [np.unique(chain['entity_id'][chain['entity_id'] > 0]) for chain in chains]))) + return num_unique_chains == 1 or "msa_all_seq" not in chains[0] + + +def _make_msa_df(chain_features): + """Makes dataframe with msa features needed for msa pairing.""" + chain_msa = chain_features['msa_all_seq'] + query_seq = chain_msa[0] + per_seq_similarity = np.sum(query_seq[None] == chain_msa, axis=-1) / float(len(query_seq)) + per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq)) + msa_df = pd.DataFrame({ + 'msa_species_identifiers': chain_features['msa_species_identifiers_all_seq'], + 'msa_row': np.arange(len(chain_features['msa_species_identifiers_all_seq'])), + 'msa_similarity': per_seq_similarity, + 'gap': per_seq_gap + }) + return msa_df + + +def _create_species_dict(msa_df): + """Creates mapping from species to msa dataframe of that species.""" + species_lookup = {} + for species, species_df in msa_df.groupby('msa_species_identifiers'): + species_lookup[species] = species_df + return species_lookup + + +def _match_rows_by_sequence_similarity(this_species_msa_dfs): + """Finds MSA sequence pairings across chains based on sequence similarity. + + Each chain's MSA sequences are first sorted by their sequence similarity to + their respective target sequence. The sequences are then paired, starting + from the sequences most similar to their target sequence. + + Args: + this_species_msa_dfs: a list of dataframes containing MSA features for + sequences for a specific species. + + Returns: + A list of lists, each containing M indices corresponding to paired MSA rows, + where M is the number of chains. 
+ """ + all_paired_msa_rows = [] + + num_seqs = [len(species_df) for species_df in this_species_msa_dfs if species_df is not None] + take_num_seqs = np.min(num_seqs) + + sort_by_similarity = (lambda x: x.sort_values('msa_similarity', axis=0, ascending=False)) + + for species_df in this_species_msa_dfs: + if species_df is not None: + species_df_sorted = sort_by_similarity(species_df) + msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values + else: + msa_rows = [-1] * take_num_seqs # take the last 'padding' row + all_paired_msa_rows.append(msa_rows) + all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose()) + return all_paired_msa_rows + + +def pair_sequences(examples): + """Returns indices for paired MSA sequences across chains.""" + + num_examples = len(examples) + + all_chain_species_dict = [] + common_species = set() + for chain_features in examples: + msa_df = _make_msa_df(chain_features) + species_dict = _create_species_dict(msa_df) + all_chain_species_dict.append(species_dict) + common_species.update(set(species_dict)) + + common_species = sorted(common_species) + common_species.remove(b'') # Remove target sequence species. + + all_paired_msa_rows = [np.zeros(len(examples), int)] + all_paired_msa_rows_dict = {k: [] for k in range(num_examples)} + all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)] + + for species in common_species: + if not species: + continue + this_species_msa_dfs = [] + species_dfs_present = 0 + for species_dict in all_chain_species_dict: + if species in species_dict: + this_species_msa_dfs.append(species_dict[species]) + species_dfs_present += 1 + else: + this_species_msa_dfs.append(None) + + # Skip species that are present in only one chain. + if species_dfs_present <= 1: + continue + + if np.any( + np.array([len(species_df) for species_df in this_species_msa_dfs if + isinstance(species_df, pd.DataFrame)]) > 600): + continue + + paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs) + all_paired_msa_rows.extend(paired_msa_rows) + all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) + all_paired_msa_rows_dict = { + num_examples: np.array(paired_msa_rows) for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items() + } + return all_paired_msa_rows_dict + + +def reorder_paired_rows(all_paired_msa_rows_dict): + """Creates a list of indices of paired MSA rows across chains. + + Args: + all_paired_msa_rows_dict: a mapping from the number of paired chains to the + paired indices. + + Returns: + a list of lists, each containing indices of paired MSA rows across chains. + The paired-index lists are ordered by: + 1) the number of chains in the paired alignment, i.e, all-chain pairings + will come first. + 2) e-values + """ + all_paired_msa_rows = [] + + for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True): + paired_rows = all_paired_msa_rows_dict[num_pairings] + paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows])) + paired_rows_sort_index = np.argsort(paired_rows_product) + all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index]) + + return np.array(all_paired_msa_rows) + + +def pad_features(feature, feature_name): + """Add a 'padding' row at the end of the features list. + + The padding row will be selected as a 'paired' row in the case of partial + alignment - for the chain that doesn't have paired alignment. + + Args: + feature: The feature to be padded. + feature_name: The name of the feature to be padded. 
+ + Returns: + The feature with an additional padding row. + """ + assert feature.dtype != np.dtype(np.string_) + if feature_name in ('msa_all_seq', 'msa_mask_all_seq', 'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'): + padding = add_padding(feature_name, feature) + elif feature_name == 'msa_species_identifiers_all_seq': + padding = [b''] + else: + return feature + feats_padded = np.concatenate([feature, padding], axis=0) + return feats_padded + + +def create_paired_features(chains): + """Returns the original chains with paired NUM_SEQ features. + + Args: + chains: A list of feature dictionaries for each chain. + + Returns: + A list of feature dictionaries with sequence features including only + rows to be paired. + """ + chains = list(chains) + chain_keys = chains[0].keys() + + if len(chains) < 2: + return chains + updated_chains = [] + paired_chains_to_paired_row_indices = pair_sequences(chains) + paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices) + + for chain_num, chain in enumerate(chains): + new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k} + for feature_name in chain_keys: + if feature_name.endswith('_all_seq'): + feats_padded = pad_features(chain[feature_name], feature_name) + new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]] + new_chain['num_alignments_all_seq'] = np.asarray(len(paired_rows[:, chain_num])) + updated_chains.append(new_chain) + return updated_chains + + +def deduplicate_unpaired_sequences(np_chains): + """Removes unpaired sequences which duplicate a paired sequence.""" + + feature_names = np_chains[0].keys() + msa_features = MSA_FEATURES + cache_msa_features = {} + for chain in np_chains: + entity_id = int(chain["entity_id"][0]) + if entity_id not in cache_msa_features: + # Convert the msa_all_seq numpy array to a tuple for hashing. + sequence_set = set(tuple(s) for s in chain['msa_all_seq']) + keep_rows = [] + # Go through unpaired MSA seqs and remove any rows that correspond to the + # sequences that are already present in the paired MSA. 
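+        # Rows are hashed as tuples because numpy arrays are not hashable; since
+        # the query sequence appears as row 0 of both `msa` and `msa_all_seq`, it
+        # is always dropped from the unpaired block here.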
+ for row_num, seq in enumerate(chain['msa']): + if tuple(seq) not in sequence_set: + keep_rows.append(row_num) + new_msa_features = {} + for feature_name in feature_names: + if feature_name in msa_features: + if keep_rows: + new_msa_features[feature_name] = chain[feature_name][keep_rows] + else: + new_shape = list(chain[feature_name].shape) + new_shape[0] = 0 + new_msa_features[feature_name] = np.zeros(new_shape, dtype=chain[feature_name].dtype) + cache_msa_features[entity_id] = new_msa_features + for feature_name in cache_msa_features[entity_id]: + chain[feature_name] = cache_msa_features[entity_id][feature_name] + chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32) + return np_chains + + +def _crop_single_chain(chain, + msa_crop_size, + max_templates): + """Crops msa sequences to `msa_crop_size`.""" + msa_size = chain['num_alignments'] + pair_msa_sequences = "num_alignments_all_seq" in chain.keys() + if pair_msa_sequences: + msa_crop_size, msa_crop_size_all_seq = get_crop_size(chain["num_alignments_all_seq"], chain["msa_all_seq"], + msa_crop_size, msa_size) + else: + msa_crop_size = np.minimum(msa_size, msa_crop_size) + + include_templates = "template_aatype" in chain and max_templates + if include_templates: + num_templates = chain['template_aatype'].shape[0] + templates_crop_size = np.minimum(num_templates, max_templates) + + for k in chain: + k_split = k.split('_all_seq')[0] + if k_split in TEMPLATE_FEATURES: + chain[k] = chain[k][:templates_crop_size] + elif k_split in MSA_FEATURES: + if '_all_seq' in k and pair_msa_sequences: + chain[k] = chain[k][:msa_crop_size_all_seq] + else: + chain[k] = chain[k][:msa_crop_size] + + chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32) + if include_templates: + chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32) + if pair_msa_sequences: + chain['num_alignments_all_seq'] = np.asarray(msa_crop_size_all_seq, dtype=np.int32) + return chain + + +def crop_chains( + chains_list, + msa_crop_size, + max_templates): + """Crops the MSAs for a set of chains. + + Args: + chains_list: A list of chains to be cropped. + msa_crop_size: The total number of sequences to crop from the MSA. + pair_msa_sequences: Whether we are operating in sequence-pairing mode. + max_templates: The maximum templates to use per chain. + + Returns: + The chains cropped. + """ + + # Apply the cropping. + cropped_chains = [] + for chain in chains_list: + cropped_chain = _crop_single_chain( + chain, + msa_crop_size=msa_crop_size, + max_templates=max_templates) + cropped_chains.append(cropped_chain) + + return cropped_chains + + +def _pad_templates(chains, + max_templates): + """For each chain pad the number of templates to a fixed size. + + Args: + chains: A list of protein chains. + max_templates: Each chain will be padded to have this many templates. + + Returns: + The list of chains, updated to have template features padded to + max_templates. 
+ """ + for chain in chains: + for k, v in chain.items(): + if k in TEMPLATE_FEATURES: + padding = np.zeros_like(v.shape) + padding[0] = max_templates - v.shape[0] + padding = [(0, p) for p in padding] + chain[k] = np.pad(v, padding, mode='constant') + return chains + + +def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray: + """Like scipy.linalg.block_diag but with an optional padding value.""" + ones_arrs = [np.ones_like(x) for x in arrs] + off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs) + diag = scipy.linalg.block_diag(*arrs) + diag += (off_diag_mask * pad_value).astype(diag.dtype) + return diag + + +def _merge_features_from_multiple_chains(chains): + """Merge features from multiple chains. + + Args: + chains: A list of feature dictionaries that we want to merge. + pair_msa_sequences: Whether to concatenate MSA features along the + num_res dimension (if True), or to block diagonalize them (if False). + + Returns: + A feature dictionary for the merged example. + """ + merged_example = {} + for feature_name in chains[0]: + feats = [x[feature_name] for x in chains] + feature_name_split = feature_name.split('_all_seq')[0] + if feature_name_split in MSA_FEATURES: + if '_all_seq' in feature_name: + merged_example[feature_name] = np.concatenate(feats, axis=1) + if feature_name_split == "msa": + merged_example["msa_chains_all_seq"] = np.ones( + merged_example[feature_name].shape[0] + ).reshape(-1, 1) + else: + merged_example[feature_name] = block_diag( + *feats, pad_value=residue_constants.MSA_PAD_VALUES[feature_name]) + #### dyh + if feature_name_split == "msa": + msa_chains = [] + for i, feat in enumerate(feats): + cur_shape = feat.shape[0] + vals = np.ones(cur_shape) * (i + 2) + msa_chains.append(vals) + merged_example["msa_chains"] = np.concatenate(msa_chains).reshape( + -1, 1 + ) + #### + elif feature_name_split in SEQ_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=0) + elif feature_name_split in TEMPLATE_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=1) + elif feature_name_split in CHAIN_FEATURES: + merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32) + else: + merged_example[feature_name] = feats[0] + return merged_example + + +def _concatenate_paired_and_unpaired_features(example): + """Merges paired and block-diagonalised features.""" + features = MSA_FEATURES + ("msa_chains",) # dyh + for feature_name in features: + if feature_name in example: + feat = example[feature_name] + feat_all_seq = example[feature_name + '_all_seq'] + merged_feat = np.concatenate([feat_all_seq, feat], axis=0) + example[feature_name] = merged_feat + example['num_alignments'] = np.array(example['msa'].shape[0], dtype=np.int32) + return example + + +def _correct_post_merged_feats( + np_example, + np_chains_list, + pair_msa_sequences): + """Adds features that need to be computed/recomputed post merging.""" + + np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0], dtype=np.int32) + np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0], dtype=np.int32) + + if not pair_msa_sequences: + # Generate a bias that is 1 for the first row of every block in the + # block diagonal MSA - i.e. make sure the cluster stack always includes + # the query sequences for each chain (since the first row is the query + # sequence). 
+ cluster_bias_masks = [] + for chain in np_chains_list: + mask = np.zeros(chain['msa'].shape[0]) + mask[0] = 1 + cluster_bias_masks.append(mask) + np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks) + + # Initialize Bert mask with masked out off diagonals. + msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for x in np_chains_list] # int8 dyh + np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0) + else: + np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0]) + np_example['cluster_bias_mask'][0] = 1 + + # Initialize Bert mask with masked out off diagonals. + msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for x in np_chains_list] # int8 dyh + msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for x in np_chains_list] # int8 dyh + + msa_mask_block_diag = block_diag(*msa_masks, pad_value=0) + msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1) + np_example['bert_mask'] = np.concatenate([msa_mask_all_seq, msa_mask_block_diag], axis=0) + return np_example + + +def merge_chain_features(np_chains_list, + max_templates): + """Merges features for multiple chains to single FeatureDict. + + Args: + np_chains_list: List of FeatureDicts for each chain. + pair_msa_sequences: Whether to merge paired MSAs. + max_templates: The maximum number of templates to include. + + Returns: + Single FeatureDict for entire complex. + """ + np_chains_list = _pad_templates(np_chains_list, max_templates=max_templates) + + np_example = _merge_features_from_multiple_chains(np_chains_list) + + pair_msa_sequences = "msa_all_seq" in np_example.keys() + if pair_msa_sequences: + np_example = _concatenate_paired_and_unpaired_features(np_example) + + np_example = _correct_post_merged_feats( + np_example=np_example, + np_chains_list=np_chains_list, + pair_msa_sequences=pair_msa_sequences) + + return np_example + + +def _filter_features(np_example): + """Filters features of example to only those requested.""" + return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES} + + +def process_final(np_example): + """Final processing steps in data pipeline, after merging and pairing.""" + # np_example["msa"] = correct_msa_restypes(np_example["msa"]) + np_example["seq_mask"] = make_seq_mask(np_example["entity_id"]) + np_example["msa_mask"] = make_msa_mask(np_example["msa"], np_example["entity_id"]) + np_example = _filter_features(np_example) + return np_example + + +def pair_and_merge(all_chain_features): + """Runs processing on features to augment, pair and merge. + + Args: + all_chain_features: A MutableMap of dictionaries of features for each chain. + + Returns: + A dictionary of features. + """ + + num_chains = len(all_chain_features) + for chain_features in all_chain_features: + # Convert deletion matrices to float. + if "deletion_matrix_int" in chain_features: + chain_features["deletion_matrix"] = np.asarray(chain_features.pop("deletion_matrix_int"), dtype=np.float32) + + chain_features["deletion_mean"] = np.mean(chain_features["deletion_matrix"], axis=0) + + # Add assembly_num_chains. + chain_features["assembly_num_chains"] = np.asarray(num_chains) + + # Add entity_mask. 
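+    # Pipeline below: per-chain cropping of MSAs/templates (MSA_CROP_SIZE=2048,
+    # MAX_TEMPLATES=4), merging of paired and block-diagonal unpaired features
+    # into a single example, then final mask construction and feature filtering.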
+ for chain_features in all_chain_features: + chain_features["entity_mask"] = (chain_features["entity_id"] != 0).astype(np.int32) + + np_chains_list = all_chain_features + + np_chains_list = crop_chains( + np_chains_list, + msa_crop_size=MSA_CROP_SIZE, #2048 + max_templates=MAX_TEMPLATES) #4 + np_example = merge_chain_features( + np_chains_list=np_chains_list, + max_templates=MAX_TEMPLATES) + np_example = process_final(np_example) + return np_example + + +def pad_msa(np_example, min_num_seq): + """ padding features with 0 if seq number less than min_num_seq. + + Args: + np_example: A feature dict with msa, deletion_matrix, bert_mask, msa_mask and cluster_bias_mask. + min_num_seq: minimal sequence number + + Returns: + np_example: padded with 0 features include msa, deletion_matrix, bert_mask, msa_mask and cluster_bias_mask. + + """ + + np_example = dict(np_example) + num_seq = np_example['msa'].shape[0] + if num_seq < min_num_seq: + for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask', "msa_chains"): + np_example[feat] = np.pad(np_example[feat], ((0, min_num_seq - num_seq), (0, 0))) + np_example['cluster_bias_mask'] = np.pad(np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),)) + return np_example + + +# These four functions are from Unifold Multimer + +def empty_template_feats(n_res): + return { + "template_aatype": np.zeros((0, n_res)).astype(np.int64), + "template_all_atom_positions": np.zeros((0, n_res, 37, 3)).astype(np.float32), + "template_sum_probs": np.zeros((0, 1)).astype(np.float32), + "template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32), + } + + +def uconvert_monomer_features(monomer_features): + """Reshapes and modifies monomer features for multimer models.""" + if monomer_features["template_aatype"].shape[0] == 0: + monomer_features.update( + empty_template_feats(monomer_features["aatype"].shape[0]) + ) + converted = {} + unnecessary_leading_dim_feats = { + "sequence", + "domain_name", + "num_alignments", + "seq_length", + } + for feature_name, feature in monomer_features.items(): + if feature_name in unnecessary_leading_dim_feats: + # asarray ensures it's a np.ndarray. + feature = np.asarray(feature[0], dtype=feature.dtype) + elif feature_name == "aatype": + # The multimer model performs the one-hot operation itself. 
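+            # i.e. the monomer pipeline stores `aatype` one-hot with shape
+            # (num_res, 21); multimer models expect integer residue indices,
+            # hence the argmax just below.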
+ feature = np.argmax(feature, axis=-1).astype(np.int32) + elif feature_name == "template_aatype": + if feature.shape[0] > 0: + feature = correct_template_restypes(feature) + elif feature_name == "template_all_atom_masks": + feature_name = "template_all_atom_mask" + elif feature_name == "msa": + feature = feature.astype(np.uint8) + + if feature_name.endswith("_mask"): + feature = feature.astype(np.float32) + + converted[feature_name] = feature + + if "deletion_matrix_int" in monomer_features: + monomer_features["deletion_matrix"] = monomer_features.pop( + "deletion_matrix_int" + ).astype(np.float32) + + converted.pop( + "template_sum_probs" + ) + return converted + + +def post_process(np_example): + np_example = pad_msa(np_example, 512) + no_dim_keys = [ + "num_alignments", + "assembly_num_chains", + "num_templates", + "seq_length", + "resolution", + ] + for k in no_dim_keys: + if k in np_example: + np_example[k] = np_example[k].reshape(-1) + return np_example + + +def merge_msas(msa, del_mat, new_msa, new_del_mat): + cur_msa_set = set([tuple(m) for m in msa]) + new_rows = [] + for i, s in enumerate(new_msa): + if tuple(s) not in cur_msa_set: + new_rows.append(i) + ret_msa = np.concatenate([msa, new_msa[new_rows]], axis=0) + ret_del_mat = np.concatenate([del_mat, new_del_mat[new_rows]], axis=0) + return ret_msa, ret_del_mat \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/data/multimer_process.py b/MindSPONGE/applications/research/Grasp/data/multimer_process.py new file mode 100644 index 0000000000000000000000000000000000000000..0011cee8dbd5094868145d9846d950bcf135227b --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/data/multimer_process.py @@ -0,0 +1,456 @@ +import numpy as np + +from mindsponge1.data.data_transform import make_atom14_masks, \ + atom37_to_frames, atom37_to_torsion_angles, pseudo_beta_fn, to_tensor_4x4 +from mindsponge1.common.utils import make_atom14_positions +from mindsponge1.common.residue_constants import atom_order + +from data.utils import numpy_seed + + +NUM_RES = 'num residues placeholder' +NUM_MSA_SEQ = 'msa placeholder' +NUM_EXTRA_SEQ = 'extra msa placeholder' +NUM_TEMPLATES = 'num templates placeholder' + + +def make_pseudo_beta(protein, prefix=""): + """Create pseudo-beta (alpha for glycine) position and mask.""" + assert prefix in ["", "template_"] + ( + protein[prefix + "pseudo_beta"], + protein[prefix + "pseudo_beta_mask"], + ) = pseudo_beta_fn( + protein["template_aatype" if prefix else "aatype"], + protein[prefix + "all_atom_positions"], + protein["template_all_atom_mask" if prefix else "all_atom_mask"], + ) + return protein + + +def get_pairwise_distances(coords): + coord_diff = np.expand_dims(coords, axis=-2) - np.expand_dims(coords, axis=-3) + return np.sqrt(np.sum(coord_diff**2, axis=-1)) + + +def get_interface_candidates_v2(ca_distances, asym_id, pair_mask, ca_ca_threshold): + + in_same_asym = asym_id[..., None] == asym_id[..., None, :] + # set distance in the same entity to zero + ca_distances = ca_distances * (1.0 - in_same_asym.astype(np.float32)) * pair_mask + interface_candidates = np.array(np.nonzero((ca_distances > 0) & (ca_distances < ca_ca_threshold))).transpose() + # print("interface_candidates", interface_candidates) + return interface_candidates + + +def get_interface_candidates(ca_distances, asym_id, pair_mask, ca_ca_threshold): + + in_same_asym = asym_id[..., None] == asym_id[..., None, :] + # set distance in the same entity to zero + ca_distances = ca_distances * (1.0 - 
in_same_asym.astype(np.float32)) * pair_mask
+    cnt_interfaces = np.sum((ca_distances > 0) & (ca_distances < ca_ca_threshold), axis=-1)
+    interface_candidates = np.nonzero(cnt_interfaces)[0]
+    return interface_candidates
+
+
+def get_crop_sizes_each_chain(asym_len, crop_size, random_seed=None, use_multinomial=False):
+    """get crop sizes for contiguous crop"""
+    if not use_multinomial:
+        with numpy_seed(random_seed, key="multimer_contiguous_perm"):
+            shuffle_idx = np.random.permutation(len(asym_len))
+        num_left = np.array(asym_len.sum())
+        num_budget = np.array(crop_size)
+        crop_sizes = [0 for _ in asym_len]
+        for j, idx in enumerate(shuffle_idx):
+            this_len = asym_len[idx]
+            num_left -= this_len
+            # num res at most we can keep in this ent
+            max_size = min(num_budget, this_len)
+            # num res at least we shall keep in this ent
+            min_size = min(this_len, max(0, num_budget - num_left))
+            with numpy_seed(random_seed, j, key="multimer_contiguous_crop_size"):
+                this_crop_size = int(np.random.randint(low=int(min_size), high=int(max_size) + 1))
+            num_budget -= this_crop_size
+            crop_sizes[idx] = this_crop_size
+        crop_sizes = np.array(crop_sizes)
+    else:  # use multinomial
+        # TODO: better multimer
+        entity_probs = asym_len / np.sum(asym_len)
+        crop_sizes = np.random.multinomial(crop_size, pvals=entity_probs)
+        # elementwise minimum; the previous np.min(crop_sizes, asym_len) passed
+        # asym_len as the axis argument and would raise
+        crop_sizes = np.minimum(crop_sizes, asym_len)
+    return crop_sizes
+
+
+def get_contiguous_crop_idx(protein, crop_size, random_seed=None, use_multinomial=False):
+
+    num_res = protein["aatype"].shape[0]
+    if num_res <= crop_size:
+        return np.arange(num_res)
+
+    assert "asym_len" in protein
+    asym_len = protein["asym_len"]
+
+    crop_sizes = get_crop_sizes_each_chain(asym_len, crop_size, random_seed, use_multinomial)
+    crop_idxs = []
+    asym_offset = np.array(0, dtype=np.int64)
+    with numpy_seed(random_seed, key="multimer_contiguous_crop_start_idx"):
+        for l, csz in zip(asym_len, crop_sizes):
+            this_start = np.random.randint(0, int(l - csz) + 1)
+            crop_idxs.append(np.arange(asym_offset + this_start, asym_offset + this_start + csz))
+            asym_offset += l
+
+    return np.concatenate(crop_idxs)
+
+
+def random_num_with_fix_total(max_value, num):
+    # generate 'num - 1' uniformly distributed integers to split the whole
+    # interval [0, max_value] into 'num' small intervals
+    a = list(np.random.uniform(0, max_value, size=(num - 1)).astype(np.int32))
+    a.append(0)
+    a.append(max_value)
+    a = sorted(a)
+    b = [a[i] - a[i - 1] for i in range(1, len(a))]
+    return b
+
+
+def get_chain_index(nk_all, res_index_all):
+    for i, seq_length in enumerate(nk_all):
+        if res_index_all < seq_length:
+            return i, res_index_all
+        res_index_all -= seq_length
+    return None
+
+
+def contact_biased_continous_cropping(chain_lengths, N_res, selected_contacts):
+
+    minimum_crop_size = 16
+    nk_all = []
+    random_crop_masks = []
+    all_seq_length = 0
+    for seq_len in chain_lengths:
+        nk_all.append(seq_len)
+        random_crop_masks.append(np.zeros(seq_len,))
+        all_seq_length += seq_len
+
+    if all_seq_length <= N_res:
+        random_crop_masks = [np.ones(mask.shape) for mask in random_crop_masks]
+        return np.concatenate(random_crop_masks)
+
+    num_contacts = selected_contacts.shape[0]
+    used_contact = []
+    for i in range(num_contacts * 2):
+
+        # get res info in chain
+        res_index_all = selected_contacts[i // 2, i % 2]
+        chain_index, res_index_in_chain = get_chain_index(nk_all, res_index_all)
+
+        # get crop size
+        n_added = int(sum([mask.sum() for mask in random_crop_masks]))
+        n_left = N_res - n_added
+        if n_left < minimum_crop_size:
+            break
+        randoms = random_num_with_fix_total(n_left - minimum_crop_size, num_contacts * 2 - i)
+        cur_crop_size = min(randoms[0] + minimum_crop_size, nk_all[chain_index])
+
+        # get crop start & stop from contact infos
+        random_start = min(max(res_index_in_chain - cur_crop_size + minimum_crop_size // 2, 0),
+                           nk_all[chain_index] - cur_crop_size)
+        random_stop = min(max(res_index_in_chain - minimum_crop_size // 2 + 1, 0),
+                          nk_all[chain_index] - cur_crop_size)
+        crop_start = int(np.random.uniform(random_start, random_stop))
+        keep = list(range(crop_start, crop_start + cur_crop_size))
+        random_crop_masks[chain_index][keep] = 1
+
+        if i % 2 == 1:
+            used_contact.append(i // 2)
+    return np.concatenate(random_crop_masks)
+
+
+def get_chain_lengths(protein):
+
+    last_asym_id = -1
+    chain_length = 0
+    chain_lengths = []
+    for asym_id in protein["asym_id"]:
+        if asym_id != last_asym_id:
+            last_asym_id = asym_id
+            chain_length = 1
+            chain_lengths.append(1)
+        else:
+            chain_length += 1
+            chain_lengths[-1] = chain_length
+    asym_id = protein["asym_id"]
+    chain_lengths2 = (asym_id[None, :] == np.array(sorted(set(asym_id)))[:, None]).sum(-1)
+
+    if np.sum(np.abs(chain_lengths - chain_lengths2)) > 0:
+        print("error !!!")
+        print(list(chain_lengths))
+        print(list(chain_lengths2))
+    return chain_lengths
+
+
+def get_spatial_crop_idx_v2(protein, crop_size, random_seed, ca_ca_threshold, inf=3e4):
+
+    ca_idx = atom_order["CA"]
+    ca_coords = protein["all_atom_positions"][..., ca_idx, :]
+    # use the builtin bool: the np.bool alias was removed in NumPy 1.24
+    ca_mask = protein["all_atom_mask"][..., ca_idx].astype(bool)
+    # if there are not enough atoms to construct interface, use contiguous crop
+    if (ca_mask.sum(axis=-1) <= 1).all():
+        return get_contiguous_crop_idx(protein, crop_size, random_seed)
+
+    pair_mask = ca_mask[..., None] * ca_mask[..., None, :]
+    ca_distances = get_pairwise_distances(ca_coords)
+
+    interface_candidates = get_interface_candidates_v2(ca_distances,
+                                                       protein["asym_id"],
+                                                       pair_mask,
+                                                       ca_ca_threshold)
+
+    if interface_candidates.any():
+        with numpy_seed(random_seed, key="multimer_spatial_crop"):
+            np.random.shuffle(interface_candidates)
+    else:
+        return get_contiguous_crop_idx(protein, crop_size, random_seed)
+
+    chain_lengths = get_chain_lengths(protein)
+
+    random_masks_all = contact_biased_continous_cropping(chain_lengths, crop_size, interface_candidates)
+    ret = list(np.where(np.array(random_masks_all) > 0)[0])
+    return ret
+
+
+def get_spatial_crop_idx(protein, crop_size, random_seed, ca_ca_threshold, inf=3e4):
+
+    ca_idx = atom_order["CA"]
+    ca_coords = protein["all_atom_positions"][..., ca_idx, :]
+    ca_mask = protein["all_atom_mask"][..., ca_idx].astype(bool)  # np.bool was removed in NumPy 1.24
+    # if there are not enough atoms to construct interface, use contiguous crop
+    if (ca_mask.sum(axis=-1) <= 1).all():
+        return get_contiguous_crop_idx(protein, crop_size, random_seed)
+
+    pair_mask = ca_mask[..., None] * ca_mask[..., None, :]
+    ca_distances = get_pairwise_distances(ca_coords)
+
+    interface_candidates = get_interface_candidates(ca_distances,
+                                                    protein["asym_id"],
+                                                    pair_mask,
+                                                    ca_ca_threshold)
+
+    if interface_candidates.any():
+        with numpy_seed(random_seed, key="multimer_spatial_crop"):
+            target_res = int(np.random.choice(interface_candidates))
+    else:
+        return get_contiguous_crop_idx(protein, crop_size,
random_seed) + + to_target_distances = ca_distances[target_res] + # set inf to non-position residues + to_target_distances[~ca_mask] = inf + break_tie = (np.arange(0, to_target_distances.shape[-1], dtype=np.float32) * 1e-3) + to_target_distances += break_tie + ret = np.argsort(to_target_distances)[:crop_size] + ret.sort() + return ret + + +def apply_crop_idx(protein, shape_schema, crop_idx): + cropped_protein = {} + for k, v in protein.items(): + if k not in shape_schema: # skip items with unknown shape schema + continue + for i, dim_size in enumerate(shape_schema[k]): + if dim_size == NUM_RES: + v = np.take(v, crop_idx, axis=i) + cropped_protein[k] = v + return cropped_protein + + +def select_feat(protein, feature_list): + feature_list.pop("msa") + feature_list.pop("msa_chains") + feature_list.pop("deletion_matrix") + feature_list.pop("num_alignments") + feature_list.pop("hhblits_profile") + return {k: v for k, v in protein.items() if k in feature_list} + + +def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num_res=0, num_templates=0,): + """Guess at the MSA and sequence dimension to make fixed size.""" + def get_pad_size(cur_size, multiplier=4): + return max(multiplier, + ((cur_size + multiplier - 1) // multiplier) * multiplier + ) + if num_res is not None: + input_num_res = ( + protein["aatype"].shape[0] + if "aatype" in protein + else protein["msa_mask"].shape[1] + ) + if input_num_res != num_res: + num_res = get_pad_size(input_num_res, 4) + # if "extra_msa_mask" in protein: + # input_extra_msa_size = protein["extra_msa_mask"].shape[0] + # if input_extra_msa_size != extra_msa_size: + # print(input_extra_msa_size, extra_msa_size) + # # import time + # # time.sleep(100) + # extra_msa_size = get_pad_size(input_extra_msa_size, 8) + pad_size_map = { + NUM_RES: num_res, + NUM_MSA_SEQ: msa_cluster_size, + NUM_EXTRA_SEQ: extra_msa_size, + NUM_TEMPLATES: num_templates, + } + + for k, v in protein.items(): + # Don't transfer this to the accelerator. + if k == "extra_cluster_assignment": + continue + shape = list(v.shape) + schema = shape_schema[k] + msg = "Rank mismatch between shape and shape schema for" + assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}" + pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)] + + padding = [] + for i, p in enumerate(pad_size): + if (p - v.shape[i]) >= 0: + padding.append((0, p - v.shape[i])) + else: + padding.append((0, 0)) + v = v.take(np.arange(v.shape[i]+(p - v.shape[i])), axis=i) + if padding: + protein[k] = np.pad(v, padding) + protein[k] = protein[k].reshape(pad_size) + + return protein + + +def pad_then_stack(values): + if len(values[0].shape) >= 1: + size = max(v.shape[0] for v in values) + new_values = [] + for v in values: + if v.shape[0] < size: + res = np.zeros((size, *v.shape[1:]), dtype=values[0].dtype) + res[:v.shape[0], ...] 
= v + else: + res = v + new_values.append(res) + else: + new_values = values + return np.stack(new_values, axis=0) + + +def map_fn(fun, x): + ensembles = [fun(elem) for elem in x] + features = ensembles[0].keys() + ensembled_dict = {} + for feat in features: + ensembled_dict[feat] = pad_then_stack([dict_i[feat] for dict_i in ensembles]) + return ensembled_dict + +def get_train_labels_old(aatype, atom37_positions, atom37_mask, chain_index): + """get train labels""" + + seq_len = len(aatype) + # get ground truth of atom14 + label_features = {'aatype': aatype, + 'all_atom_positions': atom37_positions, + 'all_atom_mask': atom37_mask} + + atom14_features = make_atom14_positions(aatype, atom37_mask, atom37_positions) + atom14_keys = ["atom14_atom_exists", "atom14_gt_exists", "atom14_gt_positions", "residx_atom14_to_atom37", + "residx_atom37_to_atom14", "atom37_atom_exists", "atom14_alt_gt_positions", + "atom14_alt_gt_exists", "atom14_atom_is_ambiguous"] + for index, array in enumerate(atom14_features): + label_features[atom14_keys[index]] = array + + # get ground truth of rigid groups + rigidgroups_label_feature = atom37_to_frames(aatype, atom37_positions, atom37_mask, is_affine=True) + label_features.update(rigidgroups_label_feature) + + # get ground truth of angle + angle_label_feature = atom37_to_torsion_angles(aatype.reshape((1, -1)), + atom37_positions.reshape((1, seq_len, 37, 3)), + atom37_mask.reshape((1, seq_len, 37)), True) + label_features.update(angle_label_feature) + + # get pseudo_beta, pseudo_beta_mask + pseudo_beta, pseudo_beta_mask = pseudo_beta_fn(aatype, atom37_positions, atom37_mask) + label_features["pseudo_beta"] = pseudo_beta + label_features["pseudo_beta_mask"] = pseudo_beta_mask + label_features["chi_mask"] = label_features.get("torsion_angles_mask")[:, 3:] + label_features['torsion_angles_sin_cos'] = label_features.get('torsion_angles_sin_cos')[:, 3:, :] + label_features['backbone_affine_mask'] = pseudo_beta_mask + label_features["chain_index"] = (np.ones(pseudo_beta_mask.shape) * chain_index).astype(np.int32) + label_features["aatype_per_chain"] = label_features["aatype"] + label_features["atom37_atom_exists_per_chain"] = label_features["atom37_atom_exists"] + # print(np.allclose(label_features["atom37_atom_exists"], label_features["all_atom_mask"])) + # print(label_features["chain_index"]) + + return label_features + +def process_single_label(label, chain_index): + assert "aatype_pdb" in label + assert "all_atom_positions" in label + assert "all_atom_mask" in label + + label_features = get_train_labels_old(label["aatype_pdb"], label["all_atom_positions"], label["all_atom_mask"], chain_index) + + return label_features + +def process_labels(labels_list): + return [process_single_label(l, chain_index) for chain_index, l in enumerate(labels_list)] + +def label_transform_fn(label): + + aatype = label["aatype"] + atom14_atom_exists, residx_atom14_to_atom37, residx_atom37_to_atom14, \ + atom37_atom_exists = make_atom14_masks(aatype) + label["residx_atom14_to_atom37"] = residx_atom14_to_atom37 + label["residx_atom37_to_atom14"] = residx_atom37_to_atom14 + label["atom14_atom_exists"] = atom14_atom_exists + label["atom37_atom_exists"] = atom37_atom_exists + + all_atom_mask = label["all_atom_mask"] + all_atom_positions = label["all_atom_positions"] + atom14_atom_exists, atom14_gt_exists, atom14_gt_positions, _, _, _, \ + atom14_alt_gt_positions, atom14_alt_gt_exists, atom14_atom_is_ambiguous = \ + make_atom14_positions(aatype, all_atom_mask, all_atom_positions) + 
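As a quick standalone check of the `pad_then_stack` helper defined above (assuming it and NumPy are in scope): it zero-pads ragged leading dimensions to a common length before stacking, which is how `map_fn` combines per-ensemble feature dicts. Values are illustrative.

```python
import numpy as np

# Two feature arrays with different leading lengths, e.g. rows produced
# by two ensemble iterations.
a = np.arange(6, dtype=np.float32).reshape(3, 2)  # shape (3, 2)
b = np.arange(4, dtype=np.float32).reshape(2, 2)  # shape (2, 2)

stacked = pad_then_stack([a, b])
assert stacked.shape == (2, 3, 2)                 # padded to max length, then stacked
assert stacked[1, 2].tolist() == [0.0, 0.0]       # the shorter array is zero-padded
```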
label["atom14_atom_exists"] = atom14_atom_exists + label["atom14_gt_exists"] = atom14_gt_exists + label["atom14_gt_positions"] = atom14_gt_positions + label["atom14_alt_gt_positions"] = atom14_alt_gt_positions + label["atom14_alt_gt_exists"] = atom14_alt_gt_exists + label["atom14_atom_is_ambiguous"] = atom14_atom_is_ambiguous + + label_f = atom37_to_frames(aatype, all_atom_positions, all_atom_mask) + label["mrigidgroups_gt_frames"] = label_f["rigidgroups_gt_frames"] + label["rigidgroups_gt_exists"] = label_f["rigidgroups_gt_exists"] + label["rigidgroups_group_exists"] = label_f["rigidgroups_group_exists"] + label["rigidgroups_group_is_ambiguous"] = label_f["rigidgroups_group_is_ambiguous"] + label["mrigidgroups_alt_gt_frames"] = label_f["rigidgroups_alt_gt_frames"] + + label["rigidgroups_gt_frames"] = to_tensor_4x4(label["mrigidgroups_gt_frames"]) + label["rigidgroups_alt_gt_frames"] = to_tensor_4x4(label["mrigidgroups_alt_gt_frames"]) + + angle_label_feature = atom37_to_torsion_angles(aatype.reshape((1, -1)), all_atom_positions.reshape((1, -1, 37, 3)), all_atom_mask.reshape((1, -1, 37)), alt_torsions=True) + label["torsion_angles_sin_cos"] = angle_label_feature["torsion_angles_sin_cos"] + label["alt_torsion_angles_sin_cos"] = angle_label_feature["alt_torsion_angles_sin_cos"] + label["torsion_angles_mask"] = angle_label_feature["torsion_angles_mask"] + + label = make_pseudo_beta(label, "") + + label["true_frame_tensor"] = label["rigidgroups_gt_frames"][..., 0, :, :] + label["frame_mask"] = label["rigidgroups_gt_exists"][..., 0] + + dtype = label["all_atom_mask"].dtype + label["chi_angles_sin_cos"] = (label["torsion_angles_sin_cos"][..., 3:, :]).astype(dtype) + label["chi_mask"] = label["torsion_angles_mask"][..., 3:].astype(dtype) + + return label \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/data/parsers.py b/MindSPONGE/applications/research/Grasp/data/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..adb026b105103ca38f69023060cdd3a6e79d6aa0 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/data/parsers.py @@ -0,0 +1,621 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Read information from a mmcif format file. 
+""" +import re +import string +import collections +import io +import dataclasses +from typing import Any, Mapping, Optional, Sequence, Tuple, List +from absl import logging +from Bio import PDB +from Bio.Data import SCOPData + + + +@dataclasses.dataclass(frozen=True) +class HhrHit: + """Class representing a hit in an hhr file.""" + index: int + name: str + prob_true: float + e_value: float + score: float + aligned_cols: int + identity: float + similarity: float + sum_probs: float + neff: float + query: str + hit_sequence: str + hit_dssp: str + column_score_code: str + confidence_scores: str + indices_query: List[int] + indices_hit: List[int] + + +# Type aliases: +ChainId = str +PdbHeader = Mapping[str, Any] +PDBSTRUCTURE = PDB.Structure.Structure +SeqRes = str +MmCIFDict = Mapping[str, Sequence[str]] + + +@dataclasses.dataclass(frozen=True) +class Monomer: + id: str + num: int + + +# Note - mmCIF format provides no guarantees on the type of author-assigned +# sequence numbers. They need not be integers. +@dataclasses.dataclass(frozen=True) +class AtomSite: + residue_name: str + author_chain_id: str + mmcif_chain_id: str + author_seq_num: str + mmcif_seq_num: int + insertion_code: str + hetatm_atom: str + model_num: int + + +# Used to map SEQRES index to a residue in the structure. +@dataclasses.dataclass(frozen=True) +class ResiduePosition: + chain_id: str + residue_number: int + insertion_code: str + + +@dataclasses.dataclass(frozen=True) +class ResidueAtPosition: + position: Optional[ResiduePosition] + name: str + is_missing: bool + hetflag: str + + +@dataclasses.dataclass(frozen=True) +class MmcifObject: + """Representation of a parsed mmCIF file. + + Contains: + file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all + files being processed. + header: Biopython header. + structure: Biopython structure. + chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. + {'A': 'ABCDEFG'} + seqres_to_structure: Dict; for each chain_id contains a mapping between + SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, + 1: ResidueAtPosition, + ...}} + raw_string: The raw string used to construct the MmcifObject. + """ + file_id: str + header: PdbHeader + structure: PDBSTRUCTURE + chain_to_seqres: Mapping[ChainId, SeqRes] + seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] + raw_string: Any + + +@dataclasses.dataclass(frozen=True) +class ParsingResult: + """Returned by the parse function. + + Contains: + mmcif_object: A MmcifObject, may be None if no chain could be successfully + parsed. + errors: A dict mapping (file_id, chain_id) to any exception generated. + """ + mmcif_object: Optional[MmcifObject] + errors: Mapping[Tuple[str, str], Any] + + +def _update_hhr_residue_indices_list( + sequence, start_index, indices_list): + """Computes the relative indices for each residue with respect to the original sequence.""" + counter = start_index + for symbol in sequence: + if symbol == '-': + indices_list.append(-1) + else: + indices_list.append(counter) + counter += 1 + + +def _get_hhr_line_regex_groups( + regex_pattern: str, line: str): + match = re.match(regex_pattern, line) + if match is None: + raise RuntimeError(f'Could not parse query line {line}') + return match.groups() + + +def parse_fasta(fasta_string: str): + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. 
+ * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. + """ + sequences = [] + descriptions = [] + index = -1 + for line in fasta_string.splitlines(): + line = line.strip() + if line.startswith('>'): + index += 1 + descriptions.append(line[1:]) # Remove the '>' at the beginning. + sequences.append('') + continue + elif not line: + continue # Skip blank lines. + sequences[index] += line + + return sequences, descriptions + + +def _parse_hhr_hit(detailed_lines): + """Parses the detailed HMM HMM comparison section for a single Hit. + + This works on .hhr files generated from both HHBlits and HHSearch. + + Args: + detailed_lines: A list of lines from a single comparison section between 2 + sequences (which each have their own HMM's) + + Returns: + A dictionary with the information from that detailed comparison section + + Raises: + RuntimeError: If a certain line cannot be processed + """ + # Parse first 2 lines. + number_of_hit = int(detailed_lines[0].split()[-1]) + name_hit = detailed_lines[1][1:] + + # Parse the summary line. + pattern = ( + 'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' + ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' + ']*Template_Neff=(.*)') + match = re.match(pattern, detailed_lines[2]) + if match is None: + raise RuntimeError( + 'Could not parse section: %s. Expected this: \n%s to contain summary.' % + (detailed_lines, detailed_lines[2])) + (prob_true, e_value, score, aligned_cols, identity, similarity, sum_probs, + neff) = [float(x) for x in match.groups()] + + # The next section reads the detailed comparisons. These are in a 'human + # readable' format which has a fixed length. The strategy employed is to + # assume that each block starts with the query sequence line, and to parse + # that with a regexp in order to deduce the fixed length used for that + # block. + query = '' + hit_sequence = '' + hit_dssp = '' + column_score_code = '' + confidence_scores = '' + indices_query = [] + indices_hit = [] + length_block = None + + for line in detailed_lines[3:]: + # Parse the query sequence line + if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and not line.startswith('Q ss_pred') \ + and not line.startswith('Q Consensus')): + # Thus the first 17 characters must be 'Q ', and we can parse + # everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + + # Get the length of the parsed block using the start and finish indices, + # and ensure it is the same as the actual block length. + start = int(groups[0]) - 1 # Make index zero based. + delta_query = groups[1] + end = int(groups[2]) + num_insertions = len([x for x in delta_query if x == '-']) + length_block = end - start + num_insertions + assert length_block == len(delta_query) + + # Update the query sequence and indices list. + query += delta_query + _update_hhr_residue_indices_list(delta_query, start, indices_query) + + elif line.startswith('T '): + # Parse the hit dssp line. + if line.startswith('T ss_dssp'): + # T ss_dssp hit_dssp + patt = r'T ss_dssp[\t ]*([A-Z-]*)' + groups = _get_hhr_line_regex_groups(patt, line) + assert len(groups[0]) == length_block + hit_dssp += groups[0] + + # Parse the hit sequence. + elif (not line.startswith('T ss_pred') and + not line.startswith('T Consensus')): + # Thus the first 17 characters must be 'T ', and we can + # parse everything after that. 
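A small illustration of `parse_fasta`'s behaviour on a hypothetical two-record input: multi-line sequences are concatenated, blank lines are skipped, and the leading `>` is stripped from descriptions.

```python
fasta = """>query some description
MKTAYIAK
QRQISFVK
>hit_1
MKT-YIAK
"""
seqs, descs = parse_fasta(fasta)
assert seqs == ["MKTAYIAKQRQISFVK", "MKT-YIAK"]
assert descs == ["query some description", "hit_1"]
```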
+ # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + start = int(groups[0]) - 1 # Make index zero based. + delta_hit_sequence = groups[1] + assert length_block == len(delta_hit_sequence) + + # Update the hit sequence and indices list. + hit_sequence += delta_hit_sequence + _update_hhr_residue_indices_list( + delta_hit_sequence, start, indices_hit) + + # Parse the column score line. + elif line.startswith(' ' * 22): + assert length_block + column_score_code += line[22:length_block + 22] + + # Update confidence score. + elif line.startswith('Confidence'): + assert length_block + confidence_scores += line[22:length_block + 22] + + return HhrHit( + index=number_of_hit, + name=name_hit, + prob_true=prob_true, + e_value=e_value, + score=score, + aligned_cols=int(aligned_cols), + identity=identity, + similarity=similarity, + sum_probs=sum_probs, + neff=neff, + query=query, + hit_sequence=hit_sequence, + hit_dssp=hit_dssp, + column_score_code=column_score_code, + confidence_scores=confidence_scores, + indices_query=indices_query, + indices_hit=indices_hit, + ) + + +def parse_hhr(hhr_string: str): + """Parses the content of an entire HHR file.""" + lines = hhr_string.splitlines() + + # Each .hhr file starts with a results table, then has a sequence of hit + # "paragraphs", each paragraph starting with a line 'No '. We + # iterate through each paragraph to parse each hit. + + block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')] + + hits = [] + if block_starts: + block_starts.append(len(lines)) # Add the end of the final block. + for i in range(len(block_starts) - 1): + hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) + return hits + + +def parse_a3m(a3m_string: str): + """Parses sequences and deletion matrix from a3m format alignment. + + Args: + a3m_string: The string contents of a a3m file. The first sequence in the + file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + """ + sequences, _ = parse_fasta(a3m_string) + deletion_matrix = [] + for msa_sequence in sequences: + deletion_vec = [] + deletion_count = 0 + for j in msa_sequence: + if j.islower(): + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + # Make the MSA matrix out of aligned (deletion-free) sequences. + deletion_table = str.maketrans('', '', string.ascii_lowercase) + aligned_sequences = [s.translate(deletion_table) for s in sequences] + return aligned_sequences, deletion_matrix + + +def mmcif_loop_to_list(prefix, parsed_info): + """Extracts loop associated with a prefix from mmCIF data as a list. + + Reference for loop_ in mmCIF: + http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. 
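To make the a3m deletion-matrix convention concrete, here is a toy alignment (illustrative values, assuming `parse_a3m` is in scope): lowercase letters are residues of the hit that do not align to a query column, and their count is recorded at the next aligned position.

```python
a3m = """>query
MKTAYI
>hit
MKtaTAYI
"""
seqs, deletion_matrix = parse_a3m(a3m)
assert seqs == ["MKTAYI", "MKTAYI"]            # lowercase insertions are stripped
assert deletion_matrix[1] == [0, 0, 2, 0, 0, 0]  # two hit residues deleted at column 3
```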
+ """ + cols = [] + data = [] + for key, value in parsed_info.items(): + if key.startswith(prefix): + cols.append(key) + data.append(value) + + assert all([len(xs) == len(data[0]) for xs in data]), ('mmCIF error: Not all loops are the same length: %s' % cols) + + return [dict(zip(cols, xs)) for xs in zip(*data)] + + +def mmcif_loop_to_dict(prefix, index, parsed_info): + """Extracts loop associated with a prefix from mmCIF data as a dictionary. + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + index: Which item of loop data should serve as the key. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, + indexed by the index column. + """ + entries = mmcif_loop_to_list(prefix, parsed_info) + return {entry[index]: entry for entry in entries} + + +def parse_mmcif(*, + file_id: str, + mmcif_string: str, + catch_all_errors: bool = True): + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. + """ + errors = {} + try: + parser = PDB.MMCIFParser(QUIET=True) + handle = io.StringIO(mmcif_string) + full_structure = parser.get_structure('', handle) + first_model_structure = _get_first_model(full_structure) + # Extract the _mmcif_dict from the parser, which contains useful fields not + # reflected in the Biopython structure. + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). + valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult(None, {(file_id, ''): 'No protein chains found in this file.'}) + seq_start_num = {chain_id: min([monomer.num for monomer in seq]) for chain_id, seq in valid_chains.items()} + + # Loop over the atoms for which we have coordinates. Populate two mappings: + # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used + # the authors / Biopython). + # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). + mmcif_to_author_chain_id = {} + seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != '1': + # We only process the first model at the moment. + continue + + mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id + + if atom.mmcif_chain_id in valid_chains: + hetflag = ' ' + if atom.hetatm_atom == 'HETATM': + # Water atoms are assigned a special hetflag of W in Biopython. We + # need to do the same, so that this hetflag can be used to fetch + # a residue from the Biopython structure by id. 
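A toy example of how `mmcif_loop_to_list` and `mmcif_loop_to_dict` reshape a parsed `_mmcif_dict` (illustrative values only): column-oriented lists become one dict per loop row.

```python
parsed_info = {
    "_entity_poly_seq.entity_id": ["1", "1", "2"],
    "_entity_poly_seq.num": ["1", "2", "1"],
    "_entity_poly_seq.mon_id": ["MET", "LYS", "ALA"],
}
rows = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)
assert rows[1] == {
    "_entity_poly_seq.entity_id": "1",
    "_entity_poly_seq.num": "2",
    "_entity_poly_seq.mon_id": "LYS",
}

chem = {"_chem_comp.id": ["ALA"], "_chem_comp.type": ["L-peptide linking"]}
by_id = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", chem)
assert by_id["ALA"]["_chem_comp.type"] == "L-peptide linking"
```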
+ if atom.residue_name in ('HOH', 'WAT'): + hetflag = 'W' + else: + hetflag = 'H_' + atom.residue_name + insertion_code = atom.insertion_code + if not _is_set(atom.insertion_code): + insertion_code = ' ' + position = ResiduePosition(chain_id=atom.author_chain_id, residue_number=int( + atom.author_seq_num), insertion_code=insertion_code) + seq_idx = int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] + current = seq_to_structure_mappings.get(atom.author_chain_id, {}) + current[seq_idx] = ResidueAtPosition(position=position, + name=atom.residue_name, + is_missing=False, + hetflag=hetflag) + seq_to_structure_mappings[atom.author_chain_id] = current + + # Add missing residue information to seq_to_structure_mappings. + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id.get(chain_id) + current_mapping = seq_to_structure_mappings.get(author_chain) + for idx, monomer in enumerate(seq_info): + if idx not in current_mapping: + current_mapping[idx] = ResidueAtPosition(position=None, + name=monomer.id, + is_missing=True, + hetflag=' ') + + author_chain_to_sequence = {} + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id.get(chain_id) + seq = [] + for monomer in seq_info: + code = SCOPData.protein_letters_3to1.get(monomer.id, 'X') + seq.append(code if len(code) == 1 else 'X') + seq = ''.join(seq) + author_chain_to_sequence[author_chain] = seq + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=first_model_structure, + chain_to_seqres=author_chain_to_sequence, + seqres_to_structure=seq_to_structure_mappings, + raw_string=parsed_info) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, '')] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +def _get_first_model(structure: PDBSTRUCTURE) -> PDBSTRUCTURE: + """Returns the first model in a Biopython structure.""" + return next(structure.get_models()) + + +_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 + + +def get_release_date(parsed_info: MmCIFDict) -> str: + """Returns the oldest revision date.""" + revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date'] + return min(revision_dates) + + +def _get_header(parsed_info: MmCIFDict) -> PdbHeader: + """Returns a basic header containing method, release date and resolution.""" + header = {} + + experiments = mmcif_loop_to_list('_exptl.', parsed_info) + header['structure_method'] = ','.join([experiment['_exptl.method'].lower() for experiment in experiments]) + + # Note: The release_date here corresponds to the oldest revision. We prefer to + # use this for dataset filtering over the deposition_date. 
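End to end, `parse_mmcif` is typically driven like the sketch below; the file path and chain id are placeholders, not values from this repository.

```python
with open("1abc.cif") as f:  # hypothetical input file
    result = parse_mmcif(file_id="1abc", mmcif_string=f.read())

if result.mmcif_object is None:
    print("parsing failed:", result.errors)
else:
    obj = result.mmcif_object
    for chain_id, seq in obj.chain_to_seqres.items():
        print(chain_id, len(seq))
    # seqres_to_structure marks SEQRES positions with no resolved coordinates.
    missing = sum(r.is_missing for r in obj.seqres_to_structure["A"].values())
    print("unresolved residues in chain A:", missing)
```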
+ if '_pdbx_audit_revision_history.revision_date' in parsed_info: + header['release_date'] = get_release_date(parsed_info) + else: + logging.warning('Could not determine release_date: %s', parsed_info['_entry.id']) + + header['resolution'] = 0.00 + for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution', '_reflns.d_resolution_high'): + if res_key in parsed_info: + try: + raw_resolution = parsed_info[res_key][0] + header['resolution'] = float(raw_resolution) + except ValueError: + logging.warning('Invalid resolution format: %s', parsed_info[res_key]) + + return header + + +def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: + """Returns list of atom sites; contains data not present in the structure.""" + return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension + parsed_info['_atom_site.label_comp_id'], + parsed_info['_atom_site.auth_asym_id'], + parsed_info['_atom_site.label_asym_id'], + parsed_info['_atom_site.auth_seq_id'], + parsed_info['_atom_site.label_seq_id'], + parsed_info['_atom_site.pdbx_PDB_ins_code'], + parsed_info['_atom_site.group_PDB'], + parsed_info['_atom_site.pdbx_PDB_model_num'], + )] + + +def _get_protein_chains(*, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]: + """Extracts polymer information for protein chains only. + + Args: + parsed_info: _mmcif_dict produced by the Biopython parser. + + Returns: + A dict mapping mmcif chain id to a list of Monomers. + """ + # Get polymer information for each entity in the structure. + entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info) + + polymers = collections.defaultdict(list) + for entity_poly_seq in entity_poly_seqs: + polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append( + Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'], num=int(entity_poly_seq['_entity_poly_seq.num']))) + + # Get chemical compositions. Will allow us to identify which of these polymers + # are proteins. + chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info) + + # Get chains information for each entity. Necessary so that we can return a + # dict keyed on chain id rather than entity. + struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info) + + entity_to_mmcif_chains = collections.defaultdict(list) + for struct_asym in struct_asyms: + chain_id = struct_asym['_struct_asym.id'] + entity_id = struct_asym['_struct_asym.entity_id'] + entity_to_mmcif_chains[entity_id].append(chain_id) + + # Identify and return the valid protein chains. + valid_chains = {} + for entity_id, seq_info in polymers.items(): + chain_ids = entity_to_mmcif_chains[entity_id] + + # Reject polymers without any peptide-like components, such as DNA/RNA. 
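Note that `_get_atom_site_list` above depends on the positional order of the `AtomSite` fields matching the order of the zipped `_atom_site` columns; spelled out with placeholder values, and keeping in mind that everything arrives as a string from the mmCIF dict regardless of the dataclass annotations:

```python
site = AtomSite(
    residue_name="ALA",    # _atom_site.label_comp_id
    author_chain_id="A",   # _atom_site.auth_asym_id
    mmcif_chain_id="A",    # _atom_site.label_asym_id
    author_seq_num="10",   # _atom_site.auth_seq_id
    mmcif_seq_num="10",    # _atom_site.label_seq_id
    insertion_code="?",    # _atom_site.pdbx_PDB_ins_code
    hetatm_atom="ATOM",    # _atom_site.group_PDB
    model_num="1",         # _atom_site.pdbx_PDB_model_num
)
assert site.model_num == "1"  # parse_mmcif compares model_num as a string
```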
+ if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type'] for monomer in seq_info]): + for chain_id in chain_ids: + valid_chains[chain_id] = seq_info + return valid_chains + + +def _is_set(data: str) -> bool: + """Returns False if data is a special mmCIF character indicating 'unset'.""" + return data not in ('.', '?') diff --git a/MindSPONGE/applications/research/Grasp/data/permutation.py b/MindSPONGE/applications/research/Grasp/data/permutation.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa59c22d230d7668e812b4a225fad789026abba --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/data/permutation.py @@ -0,0 +1,835 @@ +import numpy as np +import pickle + +from data import OUTPUT_LABEL_KEYS +from mindsponge1.common.residue_constants import atom_order +from mindsponge1.data.data_transform import pseudo_beta_fn + +GT_KEYS = ["pseudo_beta", "pseudo_beta_mask", "residx_atom14_to_atom37", + "backbone_affine_tensor", "backbone_affine_mask", "rigidgroups_gt_frames", + "rigidgroups_gt_exists", "rigidgroups_alt_gt_frames", "torsion_angles_sin_cos", "chi_mask", + "atom14_gt_positions", "atom14_alt_gt_positions", "atom14_atom_is_ambiguous", "atom14_gt_exists", + "atom14_atom_exists", "atom14_alt_gt_exists", "all_atom_positions", "all_atom_mask", + "true_msa", "bert_mask", + "restype_atom14_bond_lower_bound","restype_atom14_bond_upper_bound","atomtype_radius", + "use_clamped_fape", "filter_by_solution", "asym_mask"] + + +def multi_chain_perm_align_v3(final_atom_positions, input_feats, labels, shuffle_times=3): + + + assert isinstance(labels, list) + + pred_cb_pos, pred_cb_mask = pseudo_beta_fn(input_feats["aatype"][0], final_atom_positions, input_feats["atom37_atom_exists"]) + pred_cb_pos, pred_cb_mask = pred_cb_pos.astype(np.float32), pred_cb_mask.astype(np.float32) + true_cb_poses = [] + true_cb_masks = [] + for label in labels: + true_cb_pose, true_cb_mask = pseudo_beta_fn(label["aatype_per_chain"], label["all_atom_positions"], label["all_atom_mask"]) + true_cb_poses.append(true_cb_pose.astype(np.float32)) + true_cb_masks.append(true_cb_mask.astype(np.float32)) + + unique_asym_ids = np.unique(input_feats["asym_id"]) + + per_asym_residue_index = {} + for cur_asym_id in unique_asym_ids: + asym_mask = (input_feats["asym_id"] == cur_asym_id).astype(bool) + per_asym_residue_index[int(cur_asym_id)] = input_feats["residue_index"][asym_mask] + + + + unique_entity_ids = np.unique(input_feats["entity_id"]) + entity_2_asym_list = {} + for cur_ent_id in unique_entity_ids: + ent_mask = input_feats["entity_id"] == cur_ent_id + cur_asym_id = np.unique(input_feats["asym_id"][ent_mask]) + entity_2_asym_list[int(cur_ent_id)] = cur_asym_id + + asym_2_entity_list = {} + for ent, asys in entity_2_asym_list.items(): + for asy in asys: + asym_2_entity_list[asy] = ent + + # find anchor pred chain + anchor_gt_asym, anchor_pred_asym = get_anchor_candidates( + input_feats, per_asym_residue_index, true_cb_masks + ) + anchor_gt_idxs = entity_2_asym_list[asym_2_entity_list[anchor_gt_asym]] + + max_chain_length = 0 + for cur_asym_id in anchor_pred_asym: + asym_mask = (input_feats["asym_id"] == cur_asym_id).astype(bool) + if asym_mask.sum() > max_chain_length: + max_chain_length = asym_mask.sum() + final_asym_mask = asym_mask + anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)] + + # find optimal transforms + best_rmsd = 1e9 + best_r, best_x = None, None + for anchor_gt_idx in anchor_gt_idxs: + anchor_gt_idx = anchor_gt_idx - 1 + anchor_true_pos = 
true_cb_poses[anchor_gt_idx][anchor_residue_idx] + anchor_pred_pos = pred_cb_pos[final_asym_mask] + anchor_true_mask = true_cb_masks[anchor_gt_idx][anchor_residue_idx] + anchor_pred_mask = pred_cb_mask[final_asym_mask] + r, x = get_optimal_transform( + anchor_true_pos, + anchor_pred_pos, + (anchor_true_mask * anchor_pred_mask).astype(bool), + ) + + aligned_anchor_true_pos = anchor_true_pos @ r + x + rmsd = compute_rmsd(aligned_anchor_true_pos, anchor_pred_pos, anchor_true_mask.astype(np.int32)) + if rmsd < best_rmsd: + best_rmsd = rmsd + best_r = r + best_x = x + + best_labels = None + aligned_true_cb_poses = [cb @ best_r + best_x for cb in true_cb_poses] # apply transforms + + # greedy align + best_rmsd = 1e9 + for i in range(shuffle_times): + np.random.seed(i) + shuffle_idx = np.random.permutation(unique_asym_ids.shape[0]) + np.random.seed() + shuffled_asym_ids = unique_asym_ids[shuffle_idx] + align = greedy_align( + input_feats, + per_asym_residue_index, + shuffled_asym_ids, + entity_2_asym_list, + pred_cb_pos, + pred_cb_mask, + aligned_true_cb_poses, + true_cb_masks, + ) + + merged_labels = merge_labels( + input_feats, + per_asym_residue_index, + labels, + align, + ) + + merged_ca_pose, merged_ca_mask = pseudo_beta_fn(merged_labels["aatype_per_chain"], merged_labels["all_atom_positions"], merged_labels["all_atom_mask"]) + + rmsd = kabsch_rmsd( + merged_ca_pose @ best_r + best_x, + pred_cb_pos, + (pred_cb_mask * merged_ca_mask).astype(bool), + ) + + if rmsd < best_rmsd: + best_rmsd = rmsd + best_labels = merged_labels + + return best_labels + + +def multi_chain_perm_align_v2(final_atom_positions, input_feats, labels, shuffle_times=3): + # print(input_feats["asym_id"]) + # print(input_feats["residue_index"]) + # print(input_feats["entity_id"]) + # print(input_feats["num_sym"]) + + + assert isinstance(labels, list) + + # ca_idx = atom_order["CA"] + # pred_ca_pos = final_atom_positions[..., ca_idx, :].astype(np.float32) # [bsz, nres, 3] + # pred_ca_mask = input_feats["atom37_atom_exists"][..., ca_idx].astype(np.float32) # [bsz, nres] + # # import time + # # time.sleep(10000) + # true_ca_poses = [l["all_atom_positions"][..., ca_idx, :].astype(np.float32) for l in labels] # list([nres, 3]) + # true_ca_masks = [l["all_atom_mask"][..., ca_idx].astype(np.float32) for l in labels] # list([nres,]) + + + pred_cb_pos, pred_cb_mask = pseudo_beta_fn(input_feats["aatype"][0], final_atom_positions, input_feats["atom37_atom_exists"]) + pred_cb_pos, pred_cb_mask = pred_cb_pos.astype(np.float32), pred_cb_mask.astype(np.float32) + true_cb_poses = [] + true_cb_masks = [] + for label in labels: + true_cb_pose, true_cb_mask = pseudo_beta_fn(label["aatype_per_chain"], label["all_atom_positions"], label["all_atom_mask"]) + true_cb_poses.append(true_cb_pose.astype(np.float32)) + true_cb_masks.append(true_cb_mask.astype(np.float32)) + + unique_asym_ids = np.unique(input_feats["asym_id"]) + + per_asym_residue_index = {} + for cur_asym_id in unique_asym_ids: + asym_mask = (input_feats["asym_id"] == cur_asym_id).astype(bool) + per_asym_residue_index[int(cur_asym_id)] = input_feats["residue_index"][asym_mask] + + anchor_gt_asym, anchor_pred_asym = get_anchor_candidates( + input_feats, per_asym_residue_index, true_cb_masks + ) + anchor_gt_idx = int(anchor_gt_asym) - 1 + + + unique_entity_ids = np.unique(input_feats["entity_id"]) + entity_2_asym_list = {} + for cur_ent_id in unique_entity_ids: + ent_mask = input_feats["entity_id"] == cur_ent_id + cur_asym_id = np.unique(input_feats["asym_id"][ent_mask]) + 
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id + + # find optimal transforms + best_rmsd = 1e9 + best_r, best_x = None, None + for cur_asym_id in anchor_pred_asym: + asym_mask = (input_feats["asym_id"] == cur_asym_id).astype(bool) + anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)] + anchor_true_pos = true_cb_poses[anchor_gt_idx][anchor_residue_idx] + anchor_pred_pos = pred_cb_pos[asym_mask] + anchor_true_mask = true_cb_masks[anchor_gt_idx][anchor_residue_idx] + anchor_pred_mask = pred_cb_mask[asym_mask] + r, x = get_optimal_transform( + anchor_true_pos, + anchor_pred_pos, + (anchor_true_mask * anchor_pred_mask).astype(bool), + ) + + aligned_anchor_true_pos = anchor_true_pos @ r + x + rmsd = compute_rmsd(aligned_anchor_true_pos, anchor_pred_pos, anchor_true_mask.astype(np.int32)) + if rmsd < best_rmsd: + best_rmsd = rmsd + best_r = r + best_x = x + + best_labels = None + aligned_true_cb_poses = [cb @ best_r + best_x for cb in true_cb_poses] # apply transforms + + # greedy align + best_rmsd = 1e9 + for i in range(shuffle_times): + np.random.seed(i) + shuffle_idx = np.random.permutation(unique_asym_ids.shape[0]) + np.random.seed() + shuffled_asym_ids = unique_asym_ids[shuffle_idx] + align = greedy_align( + input_feats, + per_asym_residue_index, + shuffled_asym_ids, + entity_2_asym_list, + pred_cb_pos, + pred_cb_mask, + aligned_true_cb_poses, + true_cb_masks, + ) + + merged_labels = merge_labels( + input_feats, + per_asym_residue_index, + labels, + align, + ) + + merged_ca_pose, merged_ca_mask = pseudo_beta_fn(merged_labels["aatype_per_chain"], merged_labels["all_atom_positions"], merged_labels["all_atom_mask"]) + + rmsd = kabsch_rmsd( + merged_ca_pose @ best_r + best_x, + pred_cb_pos, + (pred_cb_mask * merged_ca_mask).astype(bool), + ) + + if rmsd < best_rmsd: + best_rmsd = rmsd + best_labels = merged_labels + + # print("multi_chain_perm_align", best_rmsd) + return best_labels + + +def multi_chain_perm_align_v1(final_atom_positions, input_feats, labels, shuffle_times=2): + + + assert isinstance(labels, list) + + pred_ca_pos, pred_ca_mask = pseudo_beta_fn(input_feats["aatype"][0], final_atom_positions, input_feats["atom37_atom_exists"]) + pred_ca_pos, pred_ca_mask = pred_ca_pos.astype(np.float32), pred_ca_mask.astype(np.float32) + true_ca_poses = [] + true_ca_masks = [] + for label in labels: + true_ca_pose, true_ca_mask = pseudo_beta_fn(label["aatype_per_chain"], label["all_atom_positions"], label["all_atom_mask"]) + true_ca_poses.append(true_ca_pose.astype(np.float32)) + true_ca_masks.append(true_ca_mask.astype(np.float32)) + + unique_asym_ids = np.unique(input_feats["asym_id"]) + + per_asym_residue_index = {} + for cur_asym_id in unique_asym_ids: + asym_mask = (input_feats["asym_id"] == cur_asym_id).astype(bool) + per_asym_residue_index[int(cur_asym_id)] = input_feats["residue_index"][asym_mask] + + anchor_gt_asym, anchor_pred_asym = get_anchor_candidates( + input_feats, per_asym_residue_index, true_ca_masks + ) + anchor_gt_idx = int(anchor_gt_asym) - 1 + + best_rmsd = 1e9 + best_labels = None + + unique_entity_ids = np.unique(input_feats["entity_id"]) + entity_2_asym_list = {} + for cur_ent_id in unique_entity_ids: + ent_mask = input_feats["entity_id"] == cur_ent_id + cur_asym_id = np.unique(input_feats["asym_id"][ent_mask]) + entity_2_asym_list[int(cur_ent_id)] = cur_asym_id + + + for cur_asym_id in anchor_pred_asym: + asym_mask = (input_feats["asym_id"] == cur_asym_id).astype(bool) + anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)] + + + anchor_true_pos = 
true_ca_poses[anchor_gt_idx][anchor_residue_idx] + anchor_pred_pos = pred_ca_pos[asym_mask] + anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx] + anchor_pred_mask = pred_ca_mask[asym_mask] + r, x = get_optimal_transform( + anchor_true_pos, + anchor_pred_pos, + (anchor_true_mask * anchor_pred_mask).astype(bool), + ) + + + + aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms + + for i in range(shuffle_times): + np.random.seed(i) + shuffle_idx = np.random.permutation(unique_asym_ids.shape[0]) + np.random.seed() + shuffled_asym_ids = unique_asym_ids[shuffle_idx] + align = greedy_align( + input_feats, + per_asym_residue_index, + shuffled_asym_ids, + entity_2_asym_list, + pred_ca_pos, + pred_ca_mask, + aligned_true_ca_poses, + true_ca_masks, + ) + merged_labels = merge_labels( + input_feats, + per_asym_residue_index, + labels, + align, + ) + + merged_ca_pose, merged_ca_mask = pseudo_beta_fn(merged_labels["aatype_per_chain"], merged_labels["all_atom_positions"], merged_labels["all_atom_mask"]) + + rmsd = kabsch_rmsd( + merged_ca_pose @ r + x, + pred_ca_pos, + (pred_ca_mask * merged_ca_mask).astype(bool), + ) + + if rmsd < best_rmsd: + best_rmsd = rmsd + best_labels = merged_labels + + return best_labels + + +def get_anchor_candidates(input_feats, per_asym_residue_index, true_masks): + def find_by_num_sym(min_num_sym): + best_len = -1 + best_gt_asym = None + asym_ids = np.unique(input_feats["asym_id"][input_feats["num_sym"] == min_num_sym]) + for cur_asym_id in asym_ids: + assert cur_asym_id > 0 + cur_residue_index = per_asym_residue_index[int(cur_asym_id)] + j = int(cur_asym_id - 1) + cur_true_mask = true_masks[j][cur_residue_index] + cur_len = cur_true_mask.sum() + if cur_len > best_len: + best_len = cur_len + best_gt_asym = cur_asym_id + return best_gt_asym, best_len + + sorted_num_sym = np.sort(input_feats["num_sym"][input_feats["num_sym"] > 0]) + best_gt_asym = None + best_len = -1 + for cur_num_sym in sorted_num_sym: + if cur_num_sym <= 0: + continue + cur_gt_sym, cur_len = find_by_num_sym(cur_num_sym) + if cur_len > best_len: + best_len = cur_len + best_gt_asym = cur_gt_sym + if best_len >= 3: + break + best_entity = input_feats["entity_id"][input_feats["asym_id"] == best_gt_asym][0] + best_pred_asym = np.unique(input_feats["asym_id"][input_feats["entity_id"] == best_entity]) + return best_gt_asym, best_pred_asym + + +def get_optimal_transform(src_atoms, tgt_atoms, mask = None): + assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape) + assert src_atoms.shape[-1] == 3 + if mask is not None: + assert mask.dtype == bool + assert mask.shape[-1] == src_atoms.shape[-2] + if mask.sum() == 0: + src_atoms = np.zeros((1, 3)).astype(np.float32) + tgt_atoms = src_atoms + else: + src_atoms = src_atoms[mask, :] + tgt_atoms = tgt_atoms[mask, :] + src_center = src_atoms.mean(-2, keepdims=True) + tgt_center = tgt_atoms.mean(-2, keepdims=True) + + r = kabsch_rotation(src_atoms - src_center, tgt_atoms - tgt_center) + x = tgt_center - src_center @ r + return r, x + + +def kabsch_rotation(P, Q): + """ + Using the Kabsch algorithm with two sets of paired point P and Q, centered + around the centroid. Each vector set is represented as an NxD + matrix, where D is the the dimension of the space. 
+ The algorithm works in three steps: + - a centroid translation of P and Q (assumed done before this function + call) + - the computation of a covariance matrix C + - computation of the optimal rotation matrix U + For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm + Parameters + ---------- + P : array + (N,D) matrix, where N is points and D is dimension. + Q : array + (N,D) matrix, where N is points and D is dimension. + Returns + ------- + U : matrix + Rotation matrix (D,D) + """ + + # Computation of the covariance matrix + C = P.transpose(-1, -2) @ Q + # Computation of the optimal rotation matrix + # This can be done using singular value decomposition (SVD) + # Getting the sign of the det(V)*(W) to decide + # whether we need to correct our rotation matrix to ensure a + # right-handed coordinate system. + # And finally calculating the optimal rotation matrix U + # see http://en.wikipedia.org/wiki/Kabsch_algorithm + V, _, W = np.linalg.svd(C) + d = (np.linalg.det(V) * np.linalg.det(W)) < 0.0 + + if d: + V[:, -1] = -V[:, -1] + + # Create Rotation matrix U + U = V @ W + return U + + +def greedy_align( + input_feats, + per_asym_residue_index, + unique_asym_ids, + entity_2_asym_list, + pred_ca_pos, + pred_ca_mask, + true_ca_poses, + true_ca_masks, + ): + used = [False for _ in range(len(true_ca_poses))] + align = [] + for cur_asym_id in unique_asym_ids: + # skip padding + if cur_asym_id == 0: + continue + i = int(cur_asym_id - 1) + asym_mask = input_feats["asym_id"] == cur_asym_id + num_sym = input_feats["num_sym"][asym_mask][0] + # don't need to align + if (num_sym) == 1: + align.append((i, i)) + assert used[i] == False + used[i] = True + continue + cur_entity_ids = input_feats["entity_id"][asym_mask][0] + best_rmsd = 1e10 + best_idx = None + cur_asym_list = entity_2_asym_list[int(cur_entity_ids)] + cur_residue_index = per_asym_residue_index[int(cur_asym_id)] + cur_pred_pos = pred_ca_pos[asym_mask] + cur_pred_mask = pred_ca_mask[asym_mask] + for next_asym_id in cur_asym_list: + if next_asym_id == 0: + continue + j = int(next_asym_id - 1) + if not used[j]: # posesible candidate + cropped_pos = true_ca_poses[j][cur_residue_index] + mask = true_ca_masks[j][cur_residue_index] + rmsd = compute_rmsd( + cropped_pos, cur_pred_pos, (cur_pred_mask * mask).astype(bool) + ) + if rmsd < best_rmsd: + best_rmsd = rmsd + best_idx = j + + assert best_idx is not None + used[best_idx] = True + align.append((i, best_idx)) + + return align + + +def compute_rmsd(true_atom_pos, pred_atom_pos, atom_mask = None, eps = 1e-6,): + # shape check + sq_diff = np.square(true_atom_pos - pred_atom_pos).sum(axis=1, keepdims=False) + if len(sq_diff) == 1: + return 1e8 + if atom_mask is not None: + sq_diff = sq_diff[atom_mask] + msd = np.mean(sq_diff) + msd = np.nan_to_num(msd, nan=1e8) + return np.sqrt(msd + eps) + + +def merge_labels(input_feats, per_asym_residue_index, labels, align): + """ + input_feats: + labels: list of label dicts, each with shape [nk, *] + align: list of int, such as [2, None, 0, 1], each entry specify the corresponding label of the asym. 
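A quick numerical sanity check of `get_optimal_transform` and `compute_rmsd` above (assuming NumPy and both helpers are in scope): apply a known rigid transform to random points, recover it, and verify the residual RMSD is near zero.

```python
import numpy as np

rng = np.random.default_rng(0)
src = rng.normal(size=(10, 3)).astype(np.float32)

# Known rotation (90 degrees about z) plus a translation.
rot = np.array([[0.0, -1.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0]], dtype=np.float32)
tgt = src @ rot + np.array([1.0, -2.0, 0.5], dtype=np.float32)

mask = np.ones(10, dtype=bool)
r, x = get_optimal_transform(src, tgt, mask)
assert compute_rmsd(src @ r + x, tgt, mask) < 1e-2  # near zero up to eps
```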
+ """ + num_res = input_feats["msa_mask"].shape[-1] + outs = {} + for k, v in labels[0].items(): + if k in [ + "resolution", + ]: + continue + cur_out = {} + for i, j in align: + label = labels[j][k] + # to 1-based + cur_residue_index = per_asym_residue_index[i + 1] + cur_out[i] = label[cur_residue_index] + cur_out = [x[1] for x in sorted(cur_out.items())] + new_v = np.concatenate(cur_out, axis=0) + merged_nres = new_v.shape[0] + assert ( + merged_nres <= num_res + ), f"bad merged num res: {merged_nres} > {num_res}. something is wrong." + if merged_nres < num_res: # must pad + pad_dim = new_v.shape[1:] + pad_v = np.zeros((num_res - merged_nres, *pad_dim)).astype(new_v.dtype) + new_v = np.concatenate((new_v, pad_v), axis=0) + outs[k] = new_v + return outs + + +def kabsch_rmsd(true_atom_pos, pred_atom_pos, atom_mask,): + r, x = get_optimal_transform( + true_atom_pos, + pred_atom_pos, + atom_mask, + ) + aligned_true_atom_pos = true_atom_pos @ r + x + return compute_rmsd(aligned_true_atom_pos, pred_atom_pos, atom_mask) + + +def placeholder_data_genenrator(num_res, num_msa): + + + data = {} + data["atomtype_radius"] = np.zeros((3, )).astype(np.float16) + data["restype_atom14_bond_lower_bound"] = np.zeros((21, 14, 14)).astype(np.float16) + data["restype_atom14_bond_upper_bound"] = np.zeros((21, 14, 14)).astype(np.float16) + data["use_clamped_fape"] = np.zeros((1,)).astype(np.float16) + data["filter_by_solution"] = np.array(0).astype(np.float16) + + data["prot_name_index"] = np.zeros((1, )).astype(np.float16) + + data["seq_mask"] = np.zeros((num_res,)).astype(np.float16) + data["aatype"] = np.zeros((num_res,)).astype(np.int32) + data["residue_index"] = np.zeros((num_res,)).astype(np.int32) + data["true_msa"] = np.zeros((num_msa, num_res)).astype(np.int32) + data["bert_mask"] = np.zeros((num_msa, num_res)).astype(np.int32) + + + + data["pseudo_beta"] = np.zeros((num_res, 3)).astype(np.float16) + data["pseudo_beta_mask"] = np.zeros((num_res,)).astype(np.float16) + data["all_atom_mask"] = np.zeros((num_res, 37)).astype(np.float16) + data["atom37_atom_exists"] = np.zeros((num_res, 37)).astype(np.float16) + data["residx_atom14_to_atom37"] = np.zeros((num_res, 14)).astype(np.int32) + data["atom14_atom_exists"] = np.zeros((num_res, 14)).astype(np.float16) + data["backbone_affine_tensor"] = np.zeros((num_res, 7)).astype(np.float16) + data["backbone_affine_mask"] = np.zeros((num_res,)).astype(np.float16) + + data["atom14_gt_positions"] = np.zeros((num_res, 14, 3)).astype(np.float16) + data["atom14_alt_gt_positions"] = np.zeros((num_res, 14, 3)).astype(np.float16) + data["atom14_atom_is_ambiguous"] = np.zeros((num_res, 14)).astype(np.float16) + data["atom14_gt_exists"] = np.zeros((num_res, 14)).astype(np.float16) + data["atom14_alt_gt_exists"] = np.zeros((num_res, 14)).astype(np.float16) + + data["all_atom_positions"] = np.zeros((num_res, 37, 3)).astype(np.float16) + data["rigidgroups_gt_frames"] = np.zeros((num_res, 8, 12)).astype(np.float16) + data["rigidgroups_gt_exists"] = np.zeros((num_res, 8)).astype(np.float16) + data["rigidgroups_alt_gt_frames"] = np.zeros((num_res, 8, 12)).astype(np.float16) + data["torsion_angles_sin_cos"] = np.zeros((num_res, 4, 2)).astype(np.float16) + data["chi_mask"] = np.zeros((num_res, 4)).astype(np.float16) + + data["asym_mask"] = np.zeros((256, num_res)).astype(np.float16) + + gt_fake = [data[key] for key in GT_KEYS] + + return gt_fake + + +def ground_truth_generator(input_data, atom37_position_pred, max_recycle): + def extract_labels(d): + all_labels = [] + for 
cur_chain_index in range(np.max(d["chain_index"]) + 1): + all_label = {} + for key in OUTPUT_LABEL_KEYS: + all_label[key] = d[key][d["chain_index"] == cur_chain_index] + all_labels.append(all_label) + return all_labels + all_labels = extract_labels(input_data) + # for i, all_label in enumerate(all_labels): + # print("\n\n\n===============", i) + # for key, value in all_label.items(): + # print(key, value.shape, value.dtype) + + input_data_single = {} + for key, value in input_data.items(): + if len(value.shape) > 0 and value.shape[0] == input_data["msa_feat"].shape[0]: + value = value[max_recycle-1] + + input_data_single[key] = value + + asym_id = input_data_single["asym_id"] + asym_type = np.arange(1, np.max(asym_id) + 1) + asym_mask = (asym_id[None, :] == asym_type[:, None]).astype(np.float16) # [NC, NR] + # print(asym_mask) + asym_mask = np.pad(asym_mask, ((0, 256 - asym_mask.shape[0]), (0, 0))).astype(np.float16) + # print(asym_mask[:4]) + # print(asym_mask.shape) + input_data_single["asym_mask"] = asym_mask + + final_labels = multi_chain_perm_align_v1(atom37_position_pred, + input_data_single, + all_labels, + shuffle_times=4) + # for key, value in final_labels.items(): + # print(key, value.shape, value.dtype) + + final_labels_keys = list(final_labels.keys()) + + # print(set(GT_KEYS) - set(final_labels_keys)) + # {'bert_mask', 'true_msa', 'restype_atom14_bond_lower_bound', 'restype_atom14_bond_upper_bound', 'filter_by_solution', 'use_clamped_fape', 'atomtype_radius', } + + # print(set(final_labels_keys) - set(GT_KEYS)) + # {'chain_index', 'atom37_atom_exists_per_chain', 'aatype_per_chain'} + + # print(set(GT_KEYS).intersection(set(final_labels_keys))) + # {'atom14_alt_gt_exists', 'pseudo_beta_mask', 'all_atom_mask', 'atom14_gt_exists', 'chi_mask', 'atom14_atom_is_ambiguous', 'backbone_affine_tensor', 'pseudo_beta', 'rigidgroups_gt_frames', 'rigidgroups_gt_exists', 'all_atom_positions', 'atom14_alt_gt_positions', 'atom14_gt_positions', 'backbone_affine_mask', 'residx_atom14_to_atom37', 'rigidgroups_alt_gt_frames', 'torsion_angles_sin_cos', 'atom14_atom_exists'} + + input_keys = ['restype_atom14_bond_lower_bound', 'restype_atom14_bond_upper_bound', + 'filter_by_solution', 'use_clamped_fape', 'atomtype_radius'] + \ + ['bert_mask', 'true_msa',"asym_mask"] + + gt_keys_useful = set(GT_KEYS).intersection(set(final_labels_keys)) + + # print("\n\n\n\n final gt data====================") + final_gt_data = [] + for key in GT_KEYS: + if key in input_keys: + value = input_data_single[key] + else: + value = final_labels[key] + + final_gt_data.append(value) + # print(key, value.shape, value.dtype) + + return final_gt_data + + +def ground_truth_generator_v2(input_data, atom37_position_pred): + def extract_labels(d): + all_labels = [] + for cur_chain_index in range(np.max(d["chain_index"]) + 1): + all_label = {} + for key in OUTPUT_LABEL_KEYS: + all_label[key] = d[key][d["chain_index"] == cur_chain_index] + all_labels.append(all_label) + return all_labels + all_labels = extract_labels(input_data) + # for i, all_label in enumerate(all_labels): + # print("\n\n\n===============", i) + # for key, value in all_label.items(): + # print(key, value.shape, value.dtype) + + input_data_single = input_data + + asym_id = input_data_single["asym_id"] + asym_type = np.arange(1, np.max(asym_id) + 1) + asym_mask = (asym_id[None, :] == asym_type[:, None]).astype(np.float16) # [NC, NR] + # print(asym_mask) + asym_mask = np.pad(asym_mask, ((0, 256 - asym_mask.shape[0]), (0, 0))).astype(np.float16) + # print(asym_mask[:4]) 
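The `asym_mask` construction used here (identical in both generator variants) one-hot encodes per-residue chain membership and pads the chain axis to a fixed 256 rows; shown standalone with toy values:

```python
import numpy as np

asym_id = np.array([1, 1, 2, 2, 2], dtype=np.float32)
asym_type = np.arange(1, np.max(asym_id) + 1)
asym_mask = (asym_id[None, :] == asym_type[:, None]).astype(np.float16)  # [num_chains, num_res]
asym_mask = np.pad(asym_mask, ((0, 256 - asym_mask.shape[0]), (0, 0)))

assert asym_mask.shape == (256, 5)
assert asym_mask[0].tolist() == [1, 1, 0, 0, 0]  # chain 1 covers the first two residues
assert asym_mask[2].sum() == 0                   # rows beyond num_chains are padding
```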
+ # print(asym_mask.shape) + input_data_single["asym_mask"] = asym_mask + + final_labels = multi_chain_perm_align_v1(atom37_position_pred, + input_data_single, + all_labels, + shuffle_times=4) + # for key, value in final_labels.items(): + # print(key, value.shape, value.dtype) + + final_labels_keys = list(final_labels.keys()) + + # print(set(GT_KEYS) - set(final_labels_keys)) + # {'bert_mask', 'true_msa', 'restype_atom14_bond_lower_bound', 'restype_atom14_bond_upper_bound', 'filter_by_solution', 'use_clamped_fape', 'atomtype_radius', } + + # print(set(final_labels_keys) - set(GT_KEYS)) + # {'chain_index', 'atom37_atom_exists_per_chain', 'aatype_per_chain'} + + # print(set(GT_KEYS).intersection(set(final_labels_keys))) + # {'atom14_alt_gt_exists', 'pseudo_beta_mask', 'all_atom_mask', 'atom14_gt_exists', 'chi_mask', 'atom14_atom_is_ambiguous', 'backbone_affine_tensor', 'pseudo_beta', 'rigidgroups_gt_frames', 'rigidgroups_gt_exists', 'all_atom_positions', 'atom14_alt_gt_positions', 'atom14_gt_positions', 'backbone_affine_mask', 'residx_atom14_to_atom37', 'rigidgroups_alt_gt_frames', 'torsion_angles_sin_cos', 'atom14_atom_exists'} + + input_keys = ['restype_atom14_bond_lower_bound', 'restype_atom14_bond_upper_bound', + 'filter_by_solution', 'use_clamped_fape', 'atomtype_radius'] + \ + ['bert_mask', 'true_msa',"asym_mask"] + + gt_keys_useful = set(GT_KEYS).intersection(set(final_labels_keys)) + + # print("\n\n\n\n final gt data====================") + final_gt_data = [] + for key in GT_KEYS: + if key in input_keys: + value = input_data_single[key] + else: + value = final_labels[key] + + final_gt_data.append(value) + # print(key, value.shape, value.dtype) + + return final_gt_data + +''' + + +==========================feature +aatype (384,) int64 +residue_index (384,) int64 +seq_length () int64 +msa_chains (124, 1) float64 +template_aatype (4, 384) int64 +template_all_atom_mask (4, 384, 37) float32 +template_all_atom_positions (4, 384, 37, 3) float32 +all_atom_positions (384, 37, 3) float32 +all_atom_mask (384, 37) float32 +resolution () float32 +asym_id (384,) float64 +sym_id (384,) float64 +entity_id (384,) float64 +num_sym (384,) float64 +assembly_num_chains (1,) int64 +cluster_bias_mask (124,) float32 +bert_mask (124, 384) float32 +msa_mask (124, 384) float32 +asym_len (5,) int64 +num_recycling_iters () int64 +use_clamped_fape () int64 +is_distillation () int64 +seq_mask (384,) float32 +msa_row_mask (124,) float32 +template_mask (4,) float32 +template_pseudo_beta (4, 384, 3) float32 +template_pseudo_beta_mask (4, 384) float32 +template_torsion_angles_sin_cos (4, 384, 7, 2) float32 +template_alt_torsion_angles_sin_cos (4, 384, 7, 2) float32 +template_torsion_angles_mask (4, 384, 7) float32 +residx_atom14_to_atom37 (384, 14) int64 +residx_atom37_to_atom14 (384, 37) int64 +atom14_atom_exists (384, 14) float32 +atom37_atom_exists (384, 37) float32 +target_feat (384, 22) float32 +extra_msa (1152, 384) int64 +extra_msa_mask (1152, 384) float32 +extra_msa_row_mask (1152,) float32 +true_msa (124, 384) int64 +msa_feat (124, 384, 49) float32 +extra_msa_has_deletion (1152, 384) float32 +extra_msa_deletion_value (1152, 384) float32 + + + + +==========================labels +aatype (216,) int64 +all_atom_positions (216, 37, 3) float32 +all_atom_mask (216, 37) float32 +resolution (1,) float32 +residx_atom14_to_atom37 (216, 14) int64 +residx_atom37_to_atom14 (216, 37) int64 +atom14_atom_exists (216, 14) float32 +atom37_atom_exists (216, 37) float32 +atom14_gt_exists (216, 14) float32 +atom14_gt_positions 
(216, 14, 3) float32 +atom14_alt_gt_positions (216, 14, 3) float32 +atom14_alt_gt_exists (216, 14) float32 +atom14_atom_is_ambiguous (216, 14) float32 +rigidgroups_gt_frames (216, 8, 4, 4) float32 +rigidgroups_gt_exists (216, 8) float32 +rigidgroups_group_exists (216, 8) float32 +rigidgroups_group_is_ambiguous (216, 8) float32 +rigidgroups_alt_gt_frames (216, 8, 4, 4) float32 +torsion_angles_sin_cos (216, 7, 2) float32 +alt_torsion_angles_sin_cos (216, 7, 2) float32 +torsion_angles_mask (216, 7) float32 +pseudo_beta (216, 3) float32 +pseudo_beta_mask (216,) float32 +true_frame_tensor (216, 4, 4) float32 +frame_mask (216,) float32 +chi_angles_sin_cos (216, 4, 2) float32 +chi_mask (216, 4) float32 + + + + +==========================output +aatype (384,) int64 +all_atom_positions (384, 37, 3) float32 +all_atom_mask (384, 37) float32 +residx_atom14_to_atom37 (384, 14) int64 +residx_atom37_to_atom14 (384, 37) int64 +atom14_atom_exists (384, 14) float32 +atom37_atom_exists (384, 37) float32 +atom14_gt_exists (384, 14) float32 +atom14_gt_positions (384, 14, 3) float32 +atom14_alt_gt_positions (384, 14, 3) float32 +atom14_alt_gt_exists (384, 14) float32 +atom14_atom_is_ambiguous (384, 14) float32 +rigidgroups_gt_frames (384, 8, 4, 4) float32 +rigidgroups_gt_exists (384, 8) float32 +rigidgroups_group_exists (384, 8) float32 +rigidgroups_group_is_ambiguous (384, 8) float32 +rigidgroups_alt_gt_frames (384, 8, 4, 4) float32 +torsion_angles_sin_cos (384, 7, 2) float32 +alt_torsion_angles_sin_cos (384, 7, 2) float32 +torsion_angles_mask (384, 7) float32 +pseudo_beta (384, 3) float32 +pseudo_beta_mask (384,) float32 +true_frame_tensor (384, 4, 4) float32 +frame_mask (384,) float32 +chi_angles_sin_cos (384, 4, 2) float32 +chi_mask (384, 4) float32 + + +''' \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/data/preprocess.py b/MindSPONGE/applications/research/Grasp/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..4d575d08d841308fe937dd7b403bbb11a87eb927 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/data/preprocess.py @@ -0,0 +1,1063 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
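The preprocessing code below, like the cropping utilities earlier, leans on a `numpy_seed` context manager imported from `data/utils.py`, which is not shown in this diff. A minimal sketch of what such a helper plausibly looks like, assuming it derives a deterministic seed from its arguments and restores the global RNG state on exit; this is an illustration, not the project's actual implementation:

```python
import contextlib
import hashlib
import numpy as np

@contextlib.contextmanager
def numpy_seed(seed, *extra, key=""):
    """Assumed sketch: temporarily reseed NumPy's global RNG from
    (seed, *extra, key), then restore the previous RNG state, so different
    call sites with different keys draw independent, reproducible streams."""
    if seed is None:
        yield  # no seed given: leave the global RNG untouched
        return
    digest = hashlib.md5(f"{seed}:{extra}:{key}".encode()).hexdigest()
    state = np.random.get_state()
    np.random.seed(int(digest[:8], 16))  # 32-bit seed derived from the hash
    try:
        yield
    finally:
        np.random.set_state(state)
```

With a helper of this shape, call sites such as `with numpy_seed(random_seed, key="multimer_spatial_crop"):` draw reproducible values without perturbing unrelated randomness elsewhere in the pipeline.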
+# ============================================================================ +"""data process""" +import numpy as np +import pickle +from mindsponge1.data.data_transform import one_hot, correct_msa_restypes, randomly_replace_msa_with_unknown, \ + fix_templates_aatype, pseudo_beta_fn, make_atom14_masks, make_msa_feat_v2, make_extra_msa_feat, \ + block_delete_msa_indices, sample_msa, sample_msa_v2, make_masked_msa, make_masked_msa_v2, \ + nearest_neighbor_clusters, nearest_neighbor_clusters_v2, summarize_clusters, crop_extra_msa, \ + make_msa_feat, random_crop_to_size, generate_random_sample, atom37_to_torsion_angles +from mindsponge1.common.residue_constants import atom_type_num + +from .utils import numpy_seed +from .multimer_process import get_spatial_crop_idx_v2, get_spatial_crop_idx, get_contiguous_crop_idx, \ + apply_crop_idx, select_feat, make_fixed_size, map_fn, make_pseudo_beta +from utils_xyh import show_npdict +import pickle + +NUM_RES = 'num residues placeholder' +NUM_MSA_SEQ = 'msa placeholder' +NUM_EXTRA_SEQ = 'extra msa placeholder' +NUM_TEMPLATES = 'num templates placeholder' +NUM_SEQ = "length msa placeholder" +NUM_NOISE = 'num noise placeholder' +NUM_LATENT_DIM = "num latent placeholder" +_MSA_FEATURE_NAMES = ['msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa', 'msa_input'] + +FEATURES = { + # Static features of a protein sequence + "aatype": (np.float32, [NUM_RES, 21]), + "between_segment_residues": (np.int64, [NUM_RES, 1]), + "deletion_matrix": (np.float32, [NUM_SEQ, NUM_RES, 1]), + "msa": (np.int64, [NUM_SEQ, NUM_RES, 1]), + "num_alignments": (np.int64, [NUM_RES, 1]), + "residue_index": (np.int64, [NUM_RES, 1]), + "seq_length": (np.int64, [NUM_RES, 1]), + "all_atom_positions": (np.float32, [NUM_RES, atom_type_num, 3]), + "all_atom_mask": (np.int64, [NUM_RES, atom_type_num]), + "resolution": (np.float32, [1]), + "template_domain_names": (str, [NUM_TEMPLATES]), + "template_sum_probs": (np.float32, [NUM_TEMPLATES, 1]), + "template_aatype": (np.float32, [NUM_TEMPLATES, NUM_RES, 22]), + "template_all_atom_positions": (np.float32, [NUM_TEMPLATES, NUM_RES, atom_type_num, 3]), + "template_all_atom_masks": (np.float32, [NUM_TEMPLATES, NUM_RES, atom_type_num, 1]), + "atom14_atom_exists": (np.float32, [NUM_RES, 14]), + "atom14_gt_exists": (np.float32, [NUM_RES, 14]), + "atom14_gt_positions": (np.float32, [NUM_RES, 14, 3]), + "residx_atom14_to_atom37": (np.float32, [NUM_RES, 14]), + "residx_atom37_to_atom14": (np.float32, [NUM_RES, 37]), + "atom37_atom_exists": (np.float32, [NUM_RES, 37]), + "atom14_alt_gt_positions": (np.float32, [NUM_RES, 14, 3]), + "atom14_alt_gt_exists": (np.float32, [NUM_RES, 14]), + "atom14_atom_is_ambiguous": (np.float32, [NUM_RES, 14]), + "rigidgroups_gt_frames": (np.float32, [NUM_RES, 8, 12]), + "rigidgroups_gt_exists": (np.float32, [NUM_RES, 8]), + "rigidgroups_group_exists": (np.float32, [NUM_RES, 8]), + "rigidgroups_group_is_ambiguous": (np.float32, [NUM_RES, 8]), + "rigidgroups_alt_gt_frames": (np.float32, [NUM_RES, 8, 12]), + "backbone_affine_tensor": (np.float32, [NUM_RES, 7]), + "torsion_angles_sin_cos": (np.float32, [NUM_RES, 4, 2]), + "torsion_angles_mask": (np.float32, [NUM_RES, 7]), + "pseudo_beta": (np.float32, [NUM_RES, 3]), + "pseudo_beta_mask": (np.float32, [NUM_RES]), + "chi_mask": (np.float32, [NUM_RES, 4]), + "backbone_affine_mask": (np.float32, [NUM_RES]), +} + +feature_list = { + 'aatype': [NUM_RES], + 'all_atom_mask': [NUM_RES, None], + 'all_atom_positions': [NUM_RES, None, None], + 'alt_chi_angles': 
[NUM_RES, None], + 'atom14_alt_gt_exists': [NUM_RES, None], + 'atom14_alt_gt_positions': [NUM_RES, None, None], + 'atom14_atom_exists': [NUM_RES, None], + 'atom14_atom_is_ambiguous': [NUM_RES, None], + 'atom14_gt_exists': [NUM_RES, None], + 'atom14_gt_positions': [NUM_RES, None, None], + 'atom37_atom_exists': [NUM_RES, None], + 'backbone_affine_mask': [NUM_RES], + 'backbone_affine_tensor': [NUM_RES, None], + 'bert_mask': [NUM_MSA_SEQ, NUM_RES], + 'chi_angles': [NUM_RES, None], + 'chi_mask': [NUM_RES, None], + 'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_row_mask': [NUM_EXTRA_SEQ], + 'is_distillation': [], + 'msa_feat': [NUM_MSA_SEQ, NUM_RES, None], + 'msa_mask': [NUM_MSA_SEQ, NUM_RES], + 'msa_row_mask': [NUM_MSA_SEQ], + 'pseudo_beta': [NUM_RES, None], + 'pseudo_beta_mask': [NUM_RES], + 'random_crop_to_size_seed': [None], + 'residue_index': [NUM_RES], + 'residx_atom14_to_atom37': [NUM_RES, None], + 'residx_atom37_to_atom14': [NUM_RES, None], + 'resolution': [], + 'rigidgroups_alt_gt_frames': [NUM_RES, None, None], + 'rigidgroups_group_exists': [NUM_RES, None], + 'rigidgroups_group_is_ambiguous': [NUM_RES, None], + 'rigidgroups_gt_exists': [NUM_RES, None], + 'rigidgroups_gt_frames': [NUM_RES, None, None], + 'seq_length': [], + 'seq_mask': [NUM_RES], + 'target_feat': [NUM_RES, None], + 'template_aatype': [NUM_TEMPLATES, NUM_RES], + 'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None], + 'template_all_atom_positions': [ + NUM_TEMPLATES, NUM_RES, None, None], + 'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES], + 'template_backbone_affine_tensor': [ + NUM_TEMPLATES, NUM_RES, None], + 'template_mask': [NUM_TEMPLATES], + 'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None], + 'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES], + 'template_sum_probs': [NUM_TEMPLATES, None], + 'true_msa': [NUM_MSA_SEQ, NUM_RES], + 'torsion_angles_sin_cos': [NUM_RES, None, None], + 'msa_input': [NUM_MSA_SEQ, NUM_RES, 2], + 'query_input': [NUM_RES, 2], + 'additional_input': [NUM_RES, 4], + 'random_data': [NUM_NOISE, NUM_MSA_SEQ, NUM_RES, NUM_LATENT_DIM], + 'context_mask': [NUM_MSA_SEQ, 2] +} + +multimer_feature_list = { + "aatype": [NUM_RES], + "all_atom_mask": [NUM_RES, None], + "all_atom_positions": [NUM_RES, None, None], + "alt_chi_angles": [NUM_RES, None], + "atom14_alt_gt_exists": [NUM_RES, None], + "atom14_alt_gt_positions": [NUM_RES, None, None], + "atom14_atom_exists": [NUM_RES, None], + "atom14_atom_is_ambiguous": [NUM_RES, None], + "atom14_gt_exists": [NUM_RES, None], + "atom14_gt_positions": [NUM_RES, None, None], + "atom37_atom_exists": [NUM_RES, None], + "frame_mask": [NUM_RES], + "true_frame_tensor": [NUM_RES, None, None], + "bert_mask": [NUM_MSA_SEQ, NUM_RES], + "chi_angles_sin_cos": [NUM_RES, None, None], + "chi_mask": [NUM_RES, None], + "crop_and_fix_size_seed":[], + "deletion_matrix": [NUM_MSA_SEQ, NUM_RES], + "extra_msa_deletion_value": [NUM_EXTRA_SEQ, NUM_RES], + "extra_msa_has_deletion": [NUM_EXTRA_SEQ, NUM_RES], + "extra_msa": [NUM_EXTRA_SEQ, NUM_RES], + "extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES], + "extra_msa_row_mask": [NUM_EXTRA_SEQ], + "hhblits_profile": [NUM_RES, None], + "is_distillation": [], + "msa": [NUM_MSA_SEQ, NUM_RES], + "msa_feat": [NUM_MSA_SEQ, NUM_RES, None], + "msa_mask": [NUM_MSA_SEQ, NUM_RES], + "msa_chains": [NUM_MSA_SEQ, None], + "msa_row_mask": [NUM_MSA_SEQ], + "num_alignments": [], + "pseudo_beta": 
[NUM_RES, None], + "pseudo_beta_mask": [NUM_RES], + "residue_index": [NUM_RES], + "residx_atom14_to_atom37": [NUM_RES, None], + "residx_atom37_to_atom14": [NUM_RES, None], + "resolution": [], + "rigidgroups_alt_gt_frames": [NUM_RES, None, None, None], + "rigidgroups_group_exists": [NUM_RES, None], + "rigidgroups_group_is_ambiguous": [NUM_RES, None], + "rigidgroups_gt_exists": [NUM_RES, None], + "rigidgroups_gt_frames": [NUM_RES, None, None, None], + "seq_length": [], + "seq_mask": [NUM_RES], + "target_feat": [NUM_RES, None], + "template_aatype": [NUM_TEMPLATES, NUM_RES], + "template_all_atom_masks": [NUM_TEMPLATES, NUM_RES, None], + "template_all_atom_positions": [NUM_TEMPLATES, NUM_RES, None, None], + "template_alt_torsion_angles_sin_cos": [NUM_TEMPLATES, NUM_RES, None, None], + "template_frame_mask": [NUM_TEMPLATES, NUM_RES], + "template_frame_tensor": [NUM_TEMPLATES, NUM_RES, None, None], + "template_mask": [NUM_TEMPLATES], + "template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None], + "template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES], + "template_sum_probs": [NUM_TEMPLATES, None], + "template_torsion_angles_mask": [NUM_TEMPLATES, NUM_RES, None], + "template_torsion_angles_sin_cos": [NUM_TEMPLATES, NUM_RES, None, None], + "true_msa": [NUM_MSA_SEQ, NUM_RES], + "use_clamped_fape": [], + "assembly_num_chains": [1], + "asym_id": [NUM_RES], + "sym_id": [NUM_RES], + "entity_id": [NUM_RES], + "num_sym": [NUM_RES], + "asym_len": [None], + "cluster_bias_mask": [NUM_MSA_SEQ], +} + + +def feature_shape(feature_name, num_residues, msa_length, num_templates, features=None): + """Get the shape for the given feature name.""" + features = features or FEATURES + if feature_name.endswith("_unnormalized"): + feature_name = feature_name[:-13] + unused_dtype, raw_sizes = features.get(feature_name, (None, None)) + replacements = {NUM_RES: num_residues, + NUM_SEQ: msa_length} + + if num_templates is not None: + replacements[NUM_TEMPLATES] = num_templates + + sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes] + for dimension in sizes: + if isinstance(dimension, str): + raise ValueError("Could not parse %s (shape: %s) with values: %s" % ( + feature_name, raw_sizes, replacements)) + size_r = [int(x) for x in sizes] + return size_r + + +def parse_reshape_logic(parsed_features, features, num_template, key=None): + """Transforms parsed serial features to the correct shape.""" + # Find out what is the number of sequences and the number of alignments. + num_residues = np.reshape(parsed_features['seq_length'].astype(np.int32), (-1,))[0] + + if "num_alignments" in parsed_features: + num_msa = np.reshape(parsed_features["num_alignments"].astype(np.int32), (-1,))[0] + else: + num_msa = 0 + + if key is not None and "key" in features: + parsed_features["key"] = [key] # Expand dims from () to (1,). + + # Reshape the arrays according to the sequence length and num alignments. + for k, v in parsed_features.items(): + new_shape = feature_shape( + feature_name=k, + num_residues=num_residues, + msa_length=num_msa, + num_templates=num_template, + features=features) + new_shape_size = 1 + for dim in new_shape: + new_shape_size *= dim + + if np.size(v) != new_shape_size: + raise ValueError("the size of feature {} ({}) could not be reshaped into {}" + "".format(k, np.size(v), new_shape)) + + if "template" not in k: + # Make sure the feature we are reshaping is not empty. 
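+            # (Editorial annotation, not part of the original patch.) For
+            # reference, feature_shape() above resolves the schema placeholders
+            # to concrete sizes; with the FEATURES table in this module one
+            # would expect, for example:
+            #     feature_shape("msa", num_residues=384, msa_length=124,
+            #                   num_templates=4)
+            #     -> [124, 384, 1]    # from the schema [NUM_SEQ, NUM_RES, 1]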
+            if np.size(v) <= 0:
+                raise ValueError("The feature {} is empty.".format(k))
+        parsed_features[k] = np.reshape(v, new_shape)
+
+    return parsed_features
+
+
+def _make_features_metadata(feature_names):
+    """Makes a feature name to type and shape mapping from a list of names."""
+    # Make sure these features are always read.
+    required_features = ["sequence", "domain_name", "template_domain_names"]
+    feature_names = list(set(feature_names) - set(required_features))
+
+    features_metadata = {name: FEATURES.get(name) for name in feature_names}
+    return features_metadata
+
+
+def np_to_array_dict(np_example, features):
+    """Creates dict of arrays.
+
+    Args:
+        np_example: A dict of NumPy feature arrays.
+        features: A list of strings of feature names to be returned in the dataset.
+
+    Returns:
+        A dictionary of features mapping feature names to features. Only the given
+        features are returned, all other ones are filtered out.
+    """
+    features_metadata = _make_features_metadata(features)
+    array_dict = {k: v for k, v in np_example.items() if k in features_metadata}
+    if "template_domain_names" in np_example:
+        num_template = len(np_example["template_domain_names"])
+    else:
+        num_template = 0
+
+    # Ensures shapes are as expected. Needed for setting size of empty features
+    # e.g. when no template hits were found.
+    array_dict = parse_reshape_logic(array_dict, features_metadata, num_template)
+    array_dict['template_mask'] = np.ones([num_template], np.float32)
+    return array_dict
+
+
+class Feature:
+    """feature process"""
+
+    def __init__(self, cfg, raw_feature=None, is_training=False, model_cfg=None, is_evogen=False, is_multimer=False):
+        if raw_feature and isinstance(raw_feature, dict):
+            self.ensemble_num = 0
+            self.cfg = cfg
+            self.model_cfg = model_cfg
+            if 'deletion_matrix_int' in raw_feature:
+                raw_feature['deletion_matrix'] = (raw_feature.pop('deletion_matrix_int').astype(np.float32))
+            feature_names = cfg.common.unsupervised_features
+            if cfg.common.use_templates:
+                feature_names += cfg.common.template_features
+            self.is_training = is_training
+            self.is_evogen = is_evogen
+            self.is_multimer = is_multimer
+            if self.is_training:
+                feature_names += cfg.common.supervised_features
+            if self.is_multimer:
+                feature_names += cfg.common.multimer_features
+                feature_names += cfg.common.recycling_features
+            raw_feature = {k: v for k, v in raw_feature.items() if k in feature_names}
+            raw_feature['template_all_atom_masks'] = (raw_feature.pop('template_all_atom_mask'))
+            if not self.is_multimer:
+                raw_feature = np_to_array_dict(np_example=raw_feature, features=feature_names)
+            # with open("/data6/yhding/1228/compare/myinit_feat.pkl", "wb") as f:
+            #     pickle.dump(raw_feature, f)
+            for key in raw_feature:
+                setattr(self, key, raw_feature[key])
+
+    def non_ensemble(self, distillation=False, replace_proportion=0.0, use_templates=True):
+        """non ensemble"""
+        if self.is_multimer:
+            data = vars(self)
+            num_seq = data["msa"].shape[0]
+            seq_len = data["msa"].shape[1]
+            max_seq = self.cfg.common.max_msa_entry // seq_len
+            if num_seq > max_seq:
+                keep_index = (np.random.choice(num_seq - 1, max_seq - 1, replace=False) + 1)
+                keep_index = np.sort(keep_index)
+                keep_index = np.concatenate((np.array([0]), keep_index), axis=0)
+                for k in ["msa", "deletion_matrix", "msa_mask", "msa_row_mask",
+                          "bert_mask", "true_msa", "msa_chains"]:
+                    if k in data:
+                        setattr(self, k, data[k][keep_index])
+        if self.is_evogen:
+            msa, msa_input = correct_msa_restypes(self.msa, self.deletion_matrix, self.is_evogen)
+            setattr(self, "msa",
msa) + setattr(self, "msa_input", msa_input.astype(np.float32)) + else: + setattr(self, "msa", correct_msa_restypes(self.msa)) + setattr(self, "is_distillation", np.array(float(distillation), dtype=np.float32)) + # convert int64 to int32 + for k, v in vars(self).items(): + if k not in ("ensemble_num", "is_training", "is_evogen", "cfg", "model_cfg", "is_multimer"): + if k.endswith("_mask"): + setattr(self, k, v.astype(np.float32)) + elif v.dtype in (np.int64, np.uint8, np.int8): + setattr(self, k, v.astype(np.int32)) + if len(self.aatype.shape) == 2: + aatype = np.argmax(self.aatype, axis=-1) + setattr(self, "aatype", aatype.astype(np.int32)) + if self.is_evogen: + query_input = np.concatenate((aatype[:, None], self.deletion_matrix[0]), + axis=-1).astype(np.int32) + setattr(self, "query_input", query_input.astype(np.float32)) + data = vars(self) + if "resolution" in data and len(data["resolution"].shape) == 1: + setattr(self, "resolution", data["resolution"][0]) + namelist = ['msa', 'num_alignments', 'seq_length', 'sequence', 'superfamily', 'deletion_matrix', + 'resolution', 'between_segment_residues', 'residue_index', 'template_all_atom_masks'] + if self.is_multimer: + namelist.append('domain_name') + namelist.remove('resolution') + for k in namelist: + if k in data: + final_dim = data[k].shape[-1] + if isinstance(final_dim, int) and final_dim == 1: + setattr(self, k, np.squeeze(data[k], axis=-1)) + # Remove fake sequence dimension + for k in ['seq_length', 'num_alignments']: + if k in data and len(data[k].shape): + setattr(self, k, data[k][0]) + msa, aatype = randomly_replace_msa_with_unknown(self.msa, self.aatype, replace_proportion) + setattr(self, "msa", msa) + setattr(self, "aatype", aatype) + # seq_mask + seq_mask = np.ones(self.aatype.shape, dtype=np.float32) + setattr(self, "seq_mask", seq_mask) + # msa_mask and msa_row_mask + msa_mask = np.ones(self.msa.shape, dtype=np.float32) + msa_row_mask = np.ones(self.msa.shape[0], dtype=np.float32) + setattr(self, "msa_mask", msa_mask) + setattr(self, "msa_row_mask", msa_row_mask) + if 'hhblits_profile' not in data: + # Compute the profile for every residue (over all MSA sequences). 
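+            # (Editorial annotation, not part of the original patch.)
+            # one_hot(22, msa) has shape (num_seq, num_res, 22), so the mean
+            # over axis 0 is a per-position amino-acid frequency profile of
+            # shape (num_res, 22). In the multimer branch the rows are first
+            # weighted by msa_mask so padded sequences do not dilute it.
+            # Toy example: for msa = np.array([[0, 1], [0, 2]]) the profile is
+            # 1.0 at [0, 0] (both sequences agree at position 0) and 0.5 at
+            # [1, 1] (only one of two sequences has restype 1 at position 1).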
+ if self.is_multimer: + setattr(self, 'hhblits_profile', np.mean(one_hot(22, self.msa) * self.msa_mask[:, :, None], axis=0)) + else: + setattr(self, 'hhblits_profile', np.mean(one_hot(22, self.msa), axis=0)) + if use_templates: + if not self.is_multimer: + template_aatype = fix_templates_aatype(self.template_aatype) + setattr(self, "template_aatype", template_aatype) + else: + setattr(self, "template_mask", np.ones(self.template_aatype.shape[0], dtype=np.float32)) + template_pseudo_beta, template_pseudo_beta_mask = pseudo_beta_fn(self.template_aatype, + self.template_all_atom_positions, + self.template_all_atom_masks) + setattr(self, "template_pseudo_beta", template_pseudo_beta) + setattr(self, "template_pseudo_beta_mask", template_pseudo_beta_mask) + if self.is_multimer: + num_templates = self.template_mask.shape[-1] + max_templates = self.cfg.common.max_templates + if num_templates > 0: + if self.cfg.common.subsample_templates: + max_templates = min(max_templates, np.random.randint(0, num_templates + 1)) + template_idx = np.random.choice(num_templates, max_templates, replace=False) + else: + # use top templates + template_idx = np.arange(min(num_templates, max_templates), dtype=np.int64) + for k, v in vars(self).items(): + if k.startswith("template"): + try: + v = v[template_idx] + except Exception as ex: + print(ex.__class__, ex) + print("num_templates", num_templates) + print(k, v.shape) + print("protein_shape:", {k: v.shape for k, v in vars(self).items() if "shape" in dir(v)}) + setattr(self, k, v) + if self.cfg.common.use_template_torsion_angles: + aatype = self.template_aatype + all_atom_positions = self.template_all_atom_positions + all_atom_mask = self.template_all_atom_masks + angle_arrays_feature = atom37_to_torsion_angles(aatype, all_atom_positions, all_atom_mask, alt_torsions=False, is_multimer=self.is_multimer) + setattr(self, "template_torsion_angles_sin_cos", angle_arrays_feature["torsion_angles_sin_cos"]) + setattr(self, "template_alt_torsion_angles_sin_cos", angle_arrays_feature["alt_torsion_angles_sin_cos"]) + setattr(self, "template_torsion_angles_mask", angle_arrays_feature["torsion_angles_mask"]) + + atom14_atom_exists, residx_atom14_to_atom37, residx_atom37_to_atom14, atom37_atom_exists = \ + make_atom14_masks(self.aatype) + setattr(self, "atom14_atom_exists", atom14_atom_exists) + setattr(self, "residx_atom14_to_atom37", residx_atom14_to_atom37) + setattr(self, "residx_atom37_to_atom14", residx_atom37_to_atom14) + setattr(self, "atom37_atom_exists", atom37_atom_exists) + + if self.is_multimer: + if "between_segment_residues" in vars(self).keys(): + has_break = np.clip(self.between_segment_residues.astype(np.float32), 0, 1) + else: + has_break = np.zeros_like(self.aatype, dtype=np.float32) + if "asym_len" in vars(self): + asym_len = self.asym_len + entity_ends = np.cumsum(asym_len, axis=-1)[:-1] + has_break[entity_ends] = 1.0 + has_break = has_break.astype(np.float32) + aatype_1hot = one_hot(21, self.aatype) + if self.cfg.common.target_feat_dim == 22: + target_feat = [np.expand_dims(has_break, axis=-1), aatype_1hot] + else: + target_feat = [aatype_1hot] + setattr(self, "target_feat", np.concatenate(target_feat, axis=-1)) + + def ensemble(self, data, msa_fraction_per_block=0.3, randomize_num_blocks=True, num_blocks=5, keep_extra=True, + max_msa_clusters=124, masked_msa=None, uniform_prob=0.1, profile_prob=0.1, same_prob=0.1, + replace_fraction=0.15, msa_cluster_features=True, max_extra_msa=1024, crop_size=256, max_templates=4, + subsample_templates=True, 
fixed_size=True, seed=0, random_recycle=False): + """ensemble""" + if not self.is_multimer: + self.ensemble_num += 1 + if self.is_training: + keep_indices = block_delete_msa_indices(data["msa"], msa_fraction_per_block, randomize_num_blocks, + num_blocks) + for k in _MSA_FEATURE_NAMES: + if k in data: + data[k] = data[k][keep_indices] + is_sel, not_sel_seq, sel_seq = sample_msa(data["msa"], max_msa_clusters) + + # ensure first row of msa is input sequence + data["msa"] = np.concatenate([data["aatype"][None,:], data["msa"]], axis=0) + zero_deletion = np.zeros((data["deletion_matrix"].shape[-1])).astype(data["deletion_matrix"].dtype) + data["deletion_matrix"] = np.concatenate([zero_deletion[None,:], data["deletion_matrix"]], axis=0) + + # exist numpy random op + if self.is_multimer: + # print(data["is_distillation"]) + is_sel, not_sel_seq, sel_seq = sample_msa_v2(data["msa"], data["msa_chains"], data["msa_mask"], + max_msa_clusters, biased_msa_by_chain=self.cfg.common.biased_msa_by_chain) # True + # print(is_sel, not_sel_seq, sel_seq) # 正确 + if "msa_input" in _MSA_FEATURE_NAMES: + _MSA_FEATURE_NAMES.remove("msa_input") + _MSA_FEATURE_NAMES.append("msa_chains") + + for k in _MSA_FEATURE_NAMES: + if k in data: + if keep_extra and not is_sel: + new_shape = list(data[k].shape) + new_shape[0] = 1 + data['extra_' + k] = np.zeros(new_shape) + elif keep_extra and is_sel: + data['extra_' + k] = data[k][not_sel_seq] + if k == 'msa' and not self.is_multimer: + data['extra_msa'] = data['extra_msa'].astype(np.int32) + data[k] = data[k][sel_seq] + if masked_msa: + if self.is_evogen: + make_masked_msa_result = make_masked_msa( + data["msa"], data["hhblits_profile"], + uniform_prob, profile_prob, + same_prob, + replace_fraction, + data['residue_index'], data['msa_mask'], self.is_evogen) + data["bert_mask"], data["true_msa"], data["msa"], data["additional_input"] = make_masked_msa_result + data["additional_input"] = data["additional_input"].astype(np.float32) + elif self.is_multimer: + + data["bert_mask"], data["true_msa"], data["msa"] = make_masked_msa_v2(data["msa"], + data["hhblits_profile"], + data['msa_mask'], + data["entity_id"], + data["sym_id"], + data["num_sym"], + uniform_prob, + profile_prob, + same_prob, + replace_fraction, + share_mask=self.cfg.common.share_mask, #True + bert_mask=data["bert_mask"]) + else: + data["bert_mask"], data["true_msa"], data["msa"] = make_masked_msa(data["msa"], data["hhblits_profile"], + uniform_prob, profile_prob, + same_prob, + replace_fraction) + + if msa_cluster_features: + if self.is_multimer: + data["cluster_profile"], data["cluster_deletion_mean"] = nearest_neighbor_clusters_v2(data["msa"], + data["msa_mask"], + data["extra_msa"], + data["extra_msa_mask"], + data["deletion_matrix"], + data["extra_deletion_matrix"]) + else: + data["extra_cluster_assignment"] = nearest_neighbor_clusters(data["msa_mask"], data["msa"], + data["extra_msa_mask"], data["extra_msa"]) + data["cluster_profile"], data["cluster_deletion_mean"] = summarize_clusters(data["msa"], data["msa_mask"], + data[ + "extra_cluster_assignment"], + data["extra_msa_mask"], + data["extra_msa"], + data["extra_deletion_matrix"], + data["deletion_matrix"]) + + if self.is_multimer: + data["msa_feat"] = make_msa_feat_v2(data["msa"], data["deletion_matrix"], data["cluster_deletion_mean"], data["cluster_profile"]) + # with open("/data6/yhding/1228/ensemble_compare/my_make_msa_feat.pkl", "wb") as f: + # pickle.dump(data, f) + extra_feats = make_extra_msa_feat(data["extra_msa"], data["extra_deletion_matrix"], 
data["extra_msa_mask"], self.cfg.common.max_extra_msa) + data["extra_msa"] = extra_feats["extra_msa"] + data["extra_msa_mask"] = extra_feats["extra_msa_mask"] + data["extra_msa_has_deletion"] = extra_feats["extra_msa_has_deletion"] + data["extra_msa_deletion_value"] = extra_feats["extra_msa_deletion_value"] + + else: + if max_extra_msa: + select_indices = crop_extra_msa(data["extra_msa"], max_extra_msa) + if select_indices: + for k in _MSA_FEATURE_NAMES: + if 'extra_' + k in data: + data['extra_' + k] = data['extra_' + k][select_indices] + else: + for k in _MSA_FEATURE_NAMES: + if 'extra_' + k in data: + del data['extra_' + k] + data["extra_has_deletion"], data["extra_deletion_value"], data["msa_feat"], data["target_feat"] = make_msa_feat( + data["between_segment_residues"], data["aatype"], data["msa"], data["deletion_matrix"], + data["cluster_deletion_mean"], data["cluster_profile"], data["extra_deletion_matrix"]) + + if fixed_size: + data = {k: v for k, v in data.items() if k in feature_list} + + num_res_crop_size, num_templates_crop_size_int, num_res_crop_start, num_res_crop_size_int, \ + templates_crop_start, templates_select_indices = random_crop_to_size( + data["seq_length"], data["template_mask"], crop_size, max_templates, + subsample_templates, seed, random_recycle) + for k, v in data.items(): + if k not in feature_list or ('template' not in k and NUM_RES not in feature_list.get(k)): + continue + + # randomly permute the templates before cropping them. + if k.startswith('template') and subsample_templates: + v = v[templates_select_indices] + + crop_sizes = [] + crop_starts = [] + for i, (dim_size, dim) in enumerate(zip(feature_list.get(k), v.shape)): + is_num_res = (dim_size == NUM_RES) + if i == 0 and k.startswith('template'): + crop_size_ = num_templates_crop_size_int + crop_start = templates_crop_start + else: + crop_start = num_res_crop_start if is_num_res else 0 + crop_size_ = (num_res_crop_size_int if is_num_res else (-1 if dim is None else dim)) + crop_sizes.append(crop_size_) + crop_starts.append(crop_start) + if len(v.shape) == 1: + data[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0]] + elif len(v.shape) == 2: + data[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], + crop_starts[1]:crop_starts[1] + crop_sizes[1]] + elif len(v.shape) == 3: + data[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], + crop_starts[1]:crop_starts[1] + crop_sizes[1], + crop_starts[2]:crop_starts[2] + crop_sizes[2]] + else: + data[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], + crop_starts[1]:crop_starts[1] + crop_sizes[1], + crop_starts[2]:crop_starts[2] + crop_sizes[2], + crop_starts[3]:crop_starts[3] + crop_sizes[3]] + + data["seq_length"] = num_res_crop_size + + pad_size_map = { + NUM_RES: crop_size, + NUM_MSA_SEQ: max_msa_clusters, + NUM_EXTRA_SEQ: max_extra_msa, + NUM_TEMPLATES: max_templates, + } + + for k, v in data.items(): + if k == 'extra_cluster_assignment': + continue + shape = list(v.shape) + schema = feature_list.get(k) + assert len(shape) == len( + schema), f'Rank mismatch between shape and shape schema for {k}: {shape} vs {schema}' + + pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)] + padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] + if padding: + data[k] = np.pad(v, padding) + data[k].reshape(pad_size) + else: + for k, v in data.items(): + if k.startswith('template_'): + data[k] = v[:max_templates] + if self.is_evogen: + data["random_data"], data["context_mask"] = generate_random_sample(self.cfg, 
self.model_cfg) + data["context_mask"] = data["context_mask"].astype(np.float32) + return data + + def process_res(self, features, res, dtype): + """process result""" + arrays, prev_pos, prev_msa_first_row, prev_pair = res + if self.is_evogen: + evogen_keys = ["target_feat", "seq_mask", "aatype", "residx_atom37_to_atom14", "atom37_atom_exists", + "residue_index", "msa_mask", "msa_input", "query_input", "additional_input", "random_data", + "context_mask"] + arrays = [features[key] for key in evogen_keys] + arrays = [array.astype(dtype) if array.dtype == "float64" else array for array in arrays] + arrays = [array.astype(dtype) if array.dtype == "float32" else array for array in arrays] + res = [arrays, prev_pos, prev_msa_first_row, prev_pair] + return res + if self.is_training: + label_keys = ["pseudo_beta", "pseudo_beta_mask", "all_atom_mask", + "true_msa", "bert_mask", "residue_index", "seq_mask", + "atom37_atom_exists", "aatype", "residx_atom14_to_atom37", + "atom14_atom_exists", "backbone_affine_tensor", "backbone_affine_mask", + "atom14_gt_positions", "atom14_alt_gt_positions", + "atom14_atom_is_ambiguous", "atom14_gt_exists", "atom14_alt_gt_exists", + "all_atom_positions", "rigidgroups_gt_frames", "rigidgroups_gt_exists", + "rigidgroups_alt_gt_frames", "torsion_angles_sin_cos", "chi_mask"] + label_arrays = [features[key] for key in label_keys] + label_arrays = [array[0] for array in label_arrays] + label_arrays = [array.astype(dtype) if array.dtype == "float64" else array for array in label_arrays] + label_arrays = [array.astype(dtype) if array.dtype == "float32" else array for array in label_arrays] + res = [arrays, prev_pos, prev_msa_first_row, prev_pair, label_arrays] + return res + return res + + + def crop_and_fix_size(self, features, crop_and_fix_size_seed): + crop_feats = dict(multimer_feature_list) + crop_and_fix_size_seed = int(crop_and_fix_size_seed) + with numpy_seed(crop_and_fix_size_seed, key="multimer_crop"): + use_spatial_crop = np.random.rand() < self.cfg.common.spatial_crop_prob # 0.5 + if use_spatial_crop: + crop_idx = get_spatial_crop_idx(features, crop_size=self.cfg.common.crop_size, random_seed=crop_and_fix_size_seed, ca_ca_threshold=self.cfg.common.ca_ca_threshold) + # crop_idx = get_spatial_crop_idx_v2(features, crop_size=self.cfg.common.crop_size, random_seed=crop_and_fix_size_seed, ca_ca_threshold=self.cfg.common.ca_ca_threshold) + else: + crop_idx = get_contiguous_crop_idx(features, crop_size=self.cfg.common.crop_size, random_seed=crop_and_fix_size_seed) + # print(len(crop_idx), features["msa"].shape) + + features = apply_crop_idx(features, shape_schema=crop_feats, crop_idx=crop_idx) + + # show_npdict(features, "crop but not pad") + + return features + + def pipeline(self, cfg, mixed_precision=True, seed=0): + """feature process pipeline""" + self.non_ensemble(cfg.common.distillation, cfg.common.replace_proportion, cfg.common.use_templates) + non_ensemble_data = vars(self).copy() + + crop_and_fix_size_seed = seed + num_recycling = self.cfg.common.num_recycle + 1 # 3 + 1 + num_ensembles = self.cfg.common.num_ensembles # 1 + max_msa_clusters = self.cfg.common.max_msa_clusters - self.cfg.common.max_templates #256-4 + max_extra_msa = self.cfg.common.max_extra_msa #1024 + def wrap_ensemble(data, i): + d = data.copy() + + d = self.ensemble(d, max_msa_clusters=max_msa_clusters, #252 + max_extra_msa=max_extra_msa, #1024 + masked_msa=self.cfg.common.use_masked_msa, # True + profile_prob=self.cfg.common.profile_prob, # 0.1 + same_prob=self.cfg.common.same_prob, # 0.1 + 
uniform_prob=self.cfg.common.uniform_prob, # 0.1 + replace_fraction=self.cfg.common.replace_fraction, # 0.15 + msa_cluster_features=self.cfg.common.msa_cluster_features) #True + + # d = self.crop_and_fix_size(d, crop_and_fix_size_seed) + + if self.cfg.common.reduce_msa_clusters_by_max_templates: # True + pad_msa_clusters = self.cfg.common.max_msa_clusters - self.cfg.common.max_templates + else: + pad_msa_clusters = self.cfg.common.max_msa_clusters + crop_feats = dict(multimer_feature_list) + d = select_feat(d, crop_feats) + d = make_fixed_size(d, crop_feats, + pad_msa_clusters, # 252 + self.cfg.common.max_extra_msa, # 1024 + self.cfg.common.crop_size, # 384 + self.cfg.common.max_templates) # 4 + + return d + + features = non_ensemble_data.copy() + + features.pop("cfg") + features_new = self.crop_and_fix_size(features, crop_and_fix_size_seed) + for key in list(set(list(features.keys())) - set(list(features_new.keys()))): + features_new[key] = features[key] + features = features_new + features["seq_length"] = np.array(features["msa"].shape[1]) + # print('\n\n====================== features after crop ===================') + # show_npdict(features) + ensemble_features = map_fn( + lambda x: wrap_ensemble(features, x), + np.arange(num_recycling * num_ensembles) + ) + + if self.cfg.common.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = self.cfg.common.max_msa_clusters - self.cfg.common.max_templates + else: + pad_msa_clusters = self.cfg.common.max_msa_clusters + crop_feats = dict(multimer_feature_list) + processed_features = select_feat(features, crop_feats) + processed_features = make_fixed_size(processed_features, crop_feats, + pad_msa_clusters, + self.cfg.common.max_extra_msa, + self.cfg.common.crop_size, + self.cfg.common.max_templates) + processed_features = {k: np.stack([processed_features[k]], axis=0) for k in processed_features} + + np.set_printoptions(threshold=np.inf) + processed_features.update(ensemble_features) + # show_npdict(processed_features, "feats after ensemble") + # print(processed_features["num_sym"].shape, flush=True) + + # print(f"\n\n==========================ori processed_feat before duplicating") + # # for key, value in all_labels[0].items(): + # # print(key, value.shape, value.dtype, flush=True) + # keys = list(processed_features.keys()) + # keys.sort() + # for key in keys: + # value = processed_features[key] + # print(key, value.shape, value.dtype, flush=True) + + # for key, value in processed_features.items(): + # if value.shape[0] == 1: + # processed_features[key] = np.concatenate([value] * num_recycling, axis=0) + + # print(f"\n\n==========================ori processed_feat") + # # for key, value in all_labels[0].items(): + # # print(key, value.shape, value.dtype, flush=True) + # keys = list(processed_features.keys()) + # keys.sort() + # for key in keys: + # value = processed_features[key] + # print(key, value.shape, value.dtype, flush=True) + + def custom_padding(seq_length, array, dim, res_length): + """Pad array to fixed size.""" + padding_size = seq_length - res_length + extra_array_shape = list(array.shape) + extra_array_shape[dim] = padding_size + extra_array = np.zeros(extra_array_shape, dtype=array.dtype) + array = np.concatenate((array, extra_array), axis=dim) + return array + + + crop_1_dim_key = ['aatype', 'target_feat', 'residx_atom37_to_atom14', 'atom37_atom_exists', + 'residue_index', 'asym_id', 'sym_id', 'entity_id', 'seq_mask', "num_sym"] + crop_2_dim_key = ['msa_feat', 'template_aatype', 'template_all_atom_masks', 
'template_all_atom_positions', + 'extra_msa', 'extra_msa_deletion_value', 'extra_msa_mask', 'msa_mask', "bert_mask", "true_msa"] + + res_length = processed_features["msa_feat"].shape[2] + for key in crop_1_dim_key: + processed_features[key] = custom_padding(self.cfg.common.crop_size, processed_features[key], 1, res_length) + for key in crop_2_dim_key: + processed_features[key] = custom_padding(self.cfg.common.crop_size, processed_features[key], 2, res_length) + + num_extra_seq = processed_features['extra_msa'].shape[1] + if num_extra_seq < self.cfg.common.max_extra_msa: + for key in ["extra_msa", "extra_msa_mask", "extra_msa_deletion_value"]: + processed_features[key] = custom_padding(self.cfg.common.max_extra_msa, processed_features[key], 1, num_extra_seq) + else: + for key in ["extra_msa", "extra_msa_mask", "extra_msa_deletion_value"]: + processed_features[key] = processed_features[key][:, :self.cfg.common.max_extra_msa, :] + + processed_features["extra_msa_deletion_value"] = processed_features["extra_msa_deletion_value"] + dtype = np.float16 + for key, value in processed_features.items(): + if value.dtype == "float64": + # print(key, "hello, float64") + processed_features[key] = value.astype(dtype) + # print(processed_features[key].dtype) + if value.dtype == "float32": + processed_features[key] = value.astype(dtype) + + + + # print(f"\n\n==========================processed_feat after padding") + # # for key, value in all_labels[0].items(): + # # print(key, value.shape, value.dtype, flush=True) + # keys = list(processed_features.keys()) + # keys.sort() + # for key in keys: + # value = processed_features[key] + # print(key, value.shape, value.dtype, flush=True) + # show_npdict(processed_features, 'processed_feat after padding') + + + input_keys = ['aatype', 'residue_index', 'template_aatype', 'template_all_atom_masks', + 'template_all_atom_positions', 'asym_id', 'sym_id', 'entity_id', 'seq_mask', 'msa_mask', + 'target_feat', 'msa_feat', 'extra_msa', 'extra_msa_deletion_value', 'extra_msa_mask', + 'residx_atom37_to_atom14', 'atom37_atom_exists'] + + # input_keys.sort() + # print(f"\n\n==========================infer input") + # print(processed_features["asym_id"][0]) + # print(processed_features["sym_id"][0]) + # # import time + # # time.sleep(10) + # print(processed_features["entity_id"][0]) + # print(processed_features["residue_index"][0]) + res_arrays = [] + for key in input_keys: + value = processed_features[key] + res_arrays.append(value) + # print(key, value.shape, value.dtype) + # print(np.sum(np.abs(processed_features["msa_feat"][1] - processed_features["msa_feat"][0]))) + # print(np.sum(np.abs(processed_features["msa_feat"][2] - processed_features["msa_feat"][1]))) + + prev_pos = np.zeros([self.cfg.common.crop_size, 37, 3]).astype(dtype) + prev_msa_first_row = np.zeros([self.cfg.common.crop_size, 256]).astype(dtype) + prev_pair = np.zeros([self.cfg.common.crop_size, self.cfg.common.crop_size, 128]).astype(dtype) + num_sym = processed_features["num_sym"][0] + bert_mask = processed_features["bert_mask"] + true_msa = processed_features["true_msa"] + res = [res_arrays, prev_pos, prev_msa_first_row, prev_pair, num_sym, bert_mask, true_msa] + + return res + + + +class MultimerFeature: + """multimer feature process""" + + def __init__(self, mixed_precision=True): + self.mixed_precision = mixed_precision + + def np_mask_mean(self, mask, value, axis=None, drop_mask_channel=False, eps=1e-10): + """Numpy masked mean.""" + if drop_mask_channel: + mask = mask[..., 0] + mask_shape = mask.shape 
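+        # (Editorial annotation, not part of the original patch.) This computes
+        # a masked mean along `axis`: sum(mask * value) / denominator, where
+        # the denominator is sum(mask) * broadcast_factor + eps. The factor
+        # (set below) compensates for a mask of size 1 being broadcast against
+        # `value` along `axis`, so every broadcast copy is counted once. A
+        # hypothetical call: a mask of shape (1, 384, 1) over a value of shape
+        # (124, 384, 22) with axis=0 scales the denominator by 124.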
+ value_shape = value.shape + broadcast_factor = 1. + value_size = value_shape[axis] + mask_size = mask_shape[axis] + if mask_size == 1: + broadcast_factor *= value_size + return np.sum(mask * value, axis=axis) / (np.sum(mask, axis=axis) * broadcast_factor + eps) + + def sample_msa(self, raw_features, max_seq): + """Sample MSA randomly.""" + logits = (np.clip(np.sum(raw_features['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6 + if 'cluster_bias_mask' not in raw_features: + cluster_bias_mask = np.pad( + np.zeros(raw_features['msa'].shape[0] - 1), (1, 0), constant_values=1.) + else: + cluster_bias_mask = raw_features['cluster_bias_mask'] + logits += cluster_bias_mask * 1e6 + z = np.random.gumbel(loc=0.0, scale=1.0, size=logits.shape) + index_order = np.argsort(-(logits + z), axis=-1, kind='quicksort', order=None) + sel_idx = index_order[:max_seq] + extra_idx = index_order[max_seq:] + for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']: + if k in raw_features: + raw_features['extra_' + k] = raw_features[k][extra_idx] + raw_features[k] = raw_features[k][sel_idx] + return raw_features + + def make_masked_msa(self, raw_features, config, epsilon=1e-6): + """create data for BERT on raw MSA.""" + random_aa = np.array([0.05] * 20 + [0., 0.], dtype=np.float32) + categorical_probs = ( + config.uniform_prob * random_aa + + config.profile_prob * raw_features['msa_profile'] + + config.same_prob * np.eye(22)[raw_features['msa']]) + pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))] + pad_shapes[-1][1] = 1 + mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob + categorical_probs = np.pad(categorical_probs, pad_shapes, constant_values=mask_prob) + sh = raw_features['msa'].shape + mask_position = (np.random.uniform(0., 1., sh) < config.replace_fraction).astype(np.float32) + mask_position *= raw_features['msa_mask'] + logits = np.log(categorical_probs + epsilon) + z = np.random.gumbel(loc=0.0, scale=1.0, size=logits.shape) + bert_msa = np.eye(logits.shape[-1], dtype=logits.dtype)[np.argmax(logits + z, axis=-1)] + bert_msa = (np.where(mask_position, + np.argmax(bert_msa, axis=-1), raw_features['msa'])) + bert_msa *= (raw_features['msa_mask'].astype(np.int64)) + if 'bert_mask' in raw_features: + raw_features['bert_mask'] *= mask_position.astype(np.float32) + else: + raw_features['bert_mask'] = mask_position.astype(np.float32) + raw_features['true_msa'] = raw_features['msa'] + raw_features['msa'] = bert_msa + return raw_features + + def softmax(self, x, axis): + """ Softmax func""" + x -= np.max(x, axis=axis, keepdims=True) + x = np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True) + return x + + def nearest_neighbor_clusters(self, raw_features, gap_agreement_weight=0.): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + weights = np.array( + [1.] * 21 + [gap_agreement_weight] + [0.], dtype=np.float32) + msa_mask = raw_features['msa_mask'] + msa_one_hot = np.eye(23)[raw_features['msa']] + extra_mask = raw_features['extra_msa_mask'] + extra_one_hot = np.eye(23)[raw_features['extra_msa']] + msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot + extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot + agreement = np.einsum('mrc, nrc->nm', extra_one_hot_masked, + weights * msa_one_hot_masked) + cluster_assignment = self.softmax(1e3 * agreement, axis=0) + cluster_assignment *= np.einsum('mr, nr->mn', msa_mask, extra_mask) + cluster_count = np.sum(cluster_assignment, axis=-1) + cluster_count += 1. 
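+        # (Editorial annotation, not part of the original patch.) The +1 on
+        # cluster_count accounts for the cluster centre itself: msa_sum below
+        # adds the centre's own one-hot row (msa_one_hot_masked) on top of the
+        # soft-assigned extra sequences, so cluster_profile is the mean over
+        # the centre plus its assigned neighbours.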
+ msa_sum = np.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked) + msa_sum += msa_one_hot_masked + cluster_profile = msa_sum / cluster_count[:, None, None] + extra_deletion_matrix = raw_features['extra_deletion_matrix'] + deletion_matrix = raw_features['deletion_matrix'] + del_sum = np.einsum('nm, mc->nc', cluster_assignment, + extra_mask * extra_deletion_matrix) + del_sum += deletion_matrix + cluster_deletion_mean = del_sum / cluster_count[:, None] + return cluster_profile, cluster_deletion_mean + + def create_msa_feat(self, raw_features): + """Create and concatenate MSA features.""" + msa_1hot = np.eye(23)[raw_features['msa']] + deletion_matrix = raw_features['deletion_matrix'] + has_deletion = np.clip(deletion_matrix, 0., 1.)[..., None] + deletion_value = (np.arctan(deletion_matrix / 3.) * (2. / np.pi))[..., None] + deletion_mean_value = (np.arctan(raw_features['cluster_deletion_mean'] / 3.) * + (2. / np.pi))[..., None] + msa_feat = [ + msa_1hot, + has_deletion, + deletion_value, + raw_features['cluster_profile'], + deletion_mean_value + ] + return np.concatenate(msa_feat, axis=-1) + + def custom_padding(self, seq_length, array, dim, res_length): + """Pad array to fixed size.""" + padding_size = seq_length - res_length + extra_array_shape = list(array.shape) + extra_array_shape[dim] = padding_size + extra_array = np.zeros(extra_array_shape, dtype=array.dtype) + array = np.concatenate((array, extra_array), axis=dim) + return array + + def pipeline(self, model_cfg, data_cfg, raw_feature): + """Preprocesses Numpy feature dict in multimer model""" + if not data_cfg.common.random_recycle: + np.random.seed(0) + + features = raw_feature.copy() + features['msa_profile'] = self.np_mask_mean(features['msa_mask'][:, :, None], + np.eye(22)[features['msa']], axis=0) + + features['target_feat'] = np.eye(21)[features['aatype']] + + # if data_cfg.common.target_feat_dim == 22: + # bsr = np.zeros_like(features["aatype"], dtype=np.float32) + # has_break = np.clip(bsr, 0, 1) + # features["target_feat"] = np.concatenate([np.expand_dims(has_break, axis=-1), features['target_feat']], axis=-1) + + # print(features["target_feat"].shape) + + + features = self.sample_msa(features, model_cfg.multimer.embeddings_and_evoformer.num_msa) + features = self.make_masked_msa(features, model_cfg.multimer.embeddings_and_evoformer.masked_msa) + (features['cluster_profile'], features['cluster_deletion_mean']) = self.nearest_neighbor_clusters(features) + features['msa_feat'] = self.create_msa_feat(features) + res_length = features['aatype'].shape[0] + _, _, features['residx_atom37_to_atom14'], features['atom37_atom_exists'] = \ + make_atom14_masks(features['aatype']) + crop_0_dim_key = ['aatype', 'target_feat', 'residx_atom37_to_atom14', 'atom37_atom_exists', + 'residue_index', 'asym_id', 'sym_id', 'entity_id', 'seq_mask'] + crop_1_dim_key = ['msa_feat', 'template_aatype', 'template_all_atom_mask', 'template_all_atom_positions', + 'extra_msa', 'extra_deletion_matrix', 'extra_msa_mask', 'msa_mask'] + for key in crop_0_dim_key: + features[key] = self.custom_padding(model_cfg.seq_length, features[key], 0, res_length) + for key in crop_1_dim_key: + features[key] = self.custom_padding(model_cfg.seq_length, features[key], 1, res_length) + num_extra_seq = features['extra_msa'].shape[0] + if num_extra_seq < data_cfg.common.max_extra_msa: + for key in ["extra_msa", "extra_msa_mask", "extra_deletion_matrix"]: + features[key] = self.custom_padding(data_cfg.common.max_extra_msa, features[key], 0, num_extra_seq) + else: + 
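+            # (Editorial annotation, not part of the original patch.) When more
+            # extra MSA rows survive than max_extra_msa allows, the surplus is
+            # truncated here; the branch above instead zero-pads up to
+            # max_extra_msa, so the downstream shapes are fixed either way.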
for key in ["extra_msa", "extra_msa_mask", "extra_deletion_matrix"]: + features[key] = features[key][:data_cfg.common.max_extra_msa, :] + + features['extra_deletion_matrix'] = np.arctan(features['extra_deletion_matrix'] / 3.) * (2. / np.pi) + input_keys = ['aatype', 'residue_index', 'template_aatype', 'template_all_atom_mask', + 'template_all_atom_positions', 'asym_id', 'sym_id', 'entity_id', 'seq_mask', 'msa_mask', + 'target_feat', 'msa_feat', 'extra_msa', 'extra_deletion_matrix', 'extra_msa_mask', + 'residx_atom37_to_atom14', 'atom37_atom_exists'] + dtype = np.float32 + if self.mixed_precision: + dtype = np.float16 + print("msa_feat_sum", np.sum(features["msa_feat"]), flush=True) + arrays = [features[key] for key in input_keys] + arrays = [array.astype(dtype) if array.dtype == "float64" else array for array in arrays] + arrays = [array.astype(dtype) if array.dtype == "float32" else array for array in arrays] + return arrays + diff --git a/MindSPONGE/applications/research/Grasp/data/protein_feature.py b/MindSPONGE/applications/research/Grasp/data/protein_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..52db7991eb2760a76df21008a39b1a8d37c88565 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/data/protein_feature.py @@ -0,0 +1,168 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +protein feature generation module. 
+""" + +import numpy as np +from absl import logging + +from mindsponge1.data.data_transform import convert_monomer_features, convert_unnecessary_leading_dim_feats +from mindsponge1.common import residue_constants +from data.templates import TemplateHitFeaturizer +from data.hhsearch import HHSearch +from data.msa_query import MmseqQuery +from data.multimer_pipeline import add_assembly_features, pair_and_merge, pad_msa +from data.parsers import parse_fasta, parse_hhr, parse_a3m + + +def make_msa_features(msas, deletion_matrices): + """Constructs a feature dict of MSA features.""" + if not msas: + raise ValueError('At least one MSA must be provided.') + + int_msa = [] + deletion_matrix = [] + seen_sequences = set() + for msa_index, msa in enumerate(msas): + if not msa: + raise ValueError(f'MSA {msa_index} must contain at least one sequence.') + for sequence_index, sequence in enumerate(msa): + if sequence in seen_sequences: + continue + seen_sequences.add(sequence) + int_msa.append([residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) + deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) + + num_res = len(msas[0][0]) + num_alignments = len(int_msa) + features = {'deletion_matrix_int': np.array(deletion_matrix, dtype=np.int32), + 'deletion_matrix_int_all_seq': np.array(deletion_matrix, dtype=np.int32), + 'msa': np.array(int_msa, dtype=np.int32), + 'msa_all_seq': np.array(int_msa, dtype=np.int32), + 'num_alignments': np.array([num_alignments] * num_res, dtype=np.int32), + 'msa_species_identifiers_all_seq': np.array([b''] * num_alignments)} + return features + + +def make_sequence_features(sequence: str, description: str, num_res: int): + """Constructs a feature dict of sequence features.""" + features = {'aatype': residue_constants.sequence_to_onehot(sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True), + 'between_segment_residues': np.zeros((num_res,), dtype=np.int32), + 'domain_name': np.array([description.encode('utf-8')], dtype=np.object_), + 'residue_index': np.array(range(num_res), dtype=np.int32), + 'seq_length': np.array([num_res] * num_res, dtype=np.int32), + 'sequence': np.array([sequence.encode('utf-8')], dtype=np.object_)} + return features + + +class RawFeatureGenerator: + """Runs the alignment tools""" + + def __init__(self, database_search_config, max_hits=20, msa_length=512): + """Search the a3m info for a given FASTA file.""" + + + self.template_mmcif_dir = database_search_config.mmcif_dir + self.max_template_date = database_search_config.max_template_date + self.kalign_binary_path = database_search_config.kalign_binary_path + self.obsolete_pdbs_path = database_search_config.obsolete_pdbs_path + self.hhsearch_binary_path = database_search_config.hhsearch_binary_path + self.pdb70_database_path = database_search_config.pdb70_database_path + self.a3m_result_path = database_search_config.a3m_result_path + self.database_envdb_dir = database_search_config.database_envdb_dir + self.mmseqs_binary = database_search_config.mmseqs_binary + self.uniref30_path = database_search_config.uniref30_path + self.max_hits = max_hits + self.msa_length = msa_length + self.msa_query = MmseqQuery(database_envdb_dir=self.database_envdb_dir, + mmseqs_binary=self.mmseqs_binary, + uniref30_path=self.uniref30_path, + result_path=self.a3m_result_path) + self.hhsearch_pdb70_runner = HHSearch(binary_path=self.hhsearch_binary_path, + databases=[self.pdb70_database_path]) + + + def monomer_feature_generate(self, fasta_path): + """protein raw 
feature generation""" + template_featurizer = TemplateHitFeaturizer(mmcif_dir=self.template_mmcif_dir, + max_template_date=self.max_template_date, + max_hits=self.max_hits, + kalign_binary_path=self.kalign_binary_path, + release_dates_path=None, + obsolete_pdbs_path=self.obsolete_pdbs_path) + with open(fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parse_fasta(input_fasta_str) + if len(input_seqs) != 1: + raise ValueError(f'More than one input sequence found in {fasta_path}.') + input_sequence = input_seqs[0] + input_description = input_descs[0] + + num_res = len(input_sequence) + a3m_lines = self.msa_query.aligned_a3m_files(fasta_path, self.a3m_result_path) + + hhsearch_result = self.hhsearch_pdb70_runner.query(a3m_lines) + hhsearch_hits = parse_hhr(hhsearch_result) + + msas, deletion_matrices = parse_a3m(a3m_lines) + templates_result = template_featurizer.get_templates( + query_sequence=input_sequence, + query_pdb_code=None, + query_release_date=None, + hhr_hits=hhsearch_hits) + sequence_features = make_sequence_features( + sequence=input_sequence, + description=input_description, + num_res=num_res) + msa_features = make_msa_features(msas=(msas,), deletion_matrices=(deletion_matrices,)) + + feature_dict = {**sequence_features, **msa_features, **templates_result.features} + return feature_dict + + def multimer_feature_generate(self, fasta_paths: list): + """ multimer feature preprocess. + + Args: + fasta_paths: a list path of fasta, each fasta for one chain fasta sequence file + + Return: + multimer_feature: a combined feature for multi_chain protein + + """ + if len(fasta_paths) == 1: + logging.error("get only one fasta, will return monomer feature") + return self.monomer_feature_generate(fasta_paths[0]) + all_chain_features = {} + for id_, fasta_path_ in enumerate(fasta_paths): + chain_feature = self.monomer_feature_generate(fasta_path_) + chain_feature["chain_id"], chain_feature["aatype"], chain_feature["template_aatype"] = \ + convert_monomer_features(str(id_), chain_feature["aatype"], chain_feature["template_aatype"]) + sequence, domain_name, num_alignments, seq_length = \ + convert_unnecessary_leading_dim_feats(chain_feature["sequence"], chain_feature["domain_name"], + chain_feature["num_alignments"], chain_feature["seq_length"]) + chain_feature["sequence"] = sequence + chain_feature["domain_name"] = domain_name + chain_feature["num_alignments"] = num_alignments + chain_feature["seq_length"] = seq_length + + all_chain_features[str(id_)] = chain_feature + all_chain_features = add_assembly_features(all_chain_features) + combined_features = pair_and_merge(all_chain_features) + combined_features = pad_msa(combined_features, self.msa_length) + + return combined_features diff --git a/MindSPONGE/applications/research/Grasp/data/templates.py b/MindSPONGE/applications/research/Grasp/data/templates.py new file mode 100644 index 0000000000000000000000000000000000000000..aae39ecac4febe48692e5482aa76c1fdb922539f --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/data/templates.py @@ -0,0 +1,920 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''template'''
+import datetime
+import glob
+import os
+import re
+import dataclasses
+from typing import Any, Mapping, Optional, Sequence, Tuple
+from absl import logging
+import numpy as np
+
+from mindsponge1.common import residue_constants
+from data.kalign import Kalign
+from data.parsers import parse_mmcif, parse_a3m
+
+
+class Error(Exception):
+    """Base class for exceptions."""
+
+
+class NoChainsError(Error):
+    """An error indicating that template mmCIF didn't have any chains."""
+
+
+class SequenceNotInTemplateError(Error):
+    """An error indicating that template mmCIF didn't contain the sequence."""
+
+
+class NoAtomDataInTemplateError(Error):
+    """An error indicating that template mmCIF didn't contain atom positions."""
+
+
+class TemplateAtomMaskAllZerosError(Error):
+    """An error indicating that template mmCIF had all atom positions masked."""
+
+
+class QueryToTemplateAlignError(Error):
+    """An error indicating that the query can't be aligned to the template."""
+
+
+class CaDistanceError(Error):
+    """An error indicating that a CA atom distance exceeds a threshold."""
+
+
+class MultipleChainsError(Error):
+    """An error indicating that multiple chains were found for a given ID."""
+
+
+# Prefilter exceptions.
+class PrefilterError(Exception):
+    """A base class for template prefilter exceptions."""
+
+
+class DateError(PrefilterError):
+    """An error indicating that the hit date was after the max allowed date."""
+
+
+class PdbIdError(PrefilterError):
+    """An error indicating that the hit PDB ID was identical to the query."""
+
+
+class AlignRatioError(PrefilterError):
+    """An error indicating that the hit align ratio to the query was too small."""
+
+
+class DuplicateError(PrefilterError):
+    """An error indicating that the hit was an exact subsequence of the query."""
+
+
+class LengthError(PrefilterError):
+    """An error indicating that the hit was too short."""
+
+
+TEMPLATE_FEATURES = {
+    'template_aatype': np.float32,
+    'template_all_atom_masks': np.float32,
+    'template_all_atom_positions': np.float32,
+    'template_domain_names': object,
+    'template_e_value': np.float32,
+    'template_neff': np.float32,
+    'template_prob_true': np.float32,
+    'template_release_date': object,
+    'template_score': np.float32,
+    'template_similarity': np.float32,
+    'template_sequence': object,
+    'template_sum_probs': np.float32,
+    'template_confidence_scores': np.int64
+}
+
+
+def _get_pdb_id_and_chain(hit):
+    """Returns PDB id and chain id for an HHSearch Hit."""
+    # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
+    id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
+    if not id_match:
+        raise ValueError(f'hit.name did not start with PDBID_chain: {hit.name}')
+    pdb_id, chain_id = id_match.group(0).split('_')
+    return pdb_id.lower(), chain_id
+
+
+def _is_after_cutoff(
+        pdb_id: str,
+        release_dates: Mapping[str, datetime.datetime],
+        release_date_cutoff: Optional[datetime.datetime]) -> bool:
+    """Checks if the template date is after the release date cutoff.
+ + Args: + pdb_id: 4 letter pdb code. + release_dates: Dictionary mapping PDB ids to their structure release dates. + release_date_cutoff: Max release date that is valid for this query. + + Returns: + True if the template release date is after the cutoff, False otherwise. + """ + if release_date_cutoff is None: + raise ValueError('The release_date_cutoff must not be None.') + if pdb_id in release_dates: + return release_dates[pdb_id] > release_date_cutoff + return False + + +def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]: + """Parses the data file from PDB that lists which PDB ids are obsolete.""" + with open(obsolete_file_path) as f: + result = {} + for line in f: + line = line.strip() + # We skip obsolete entries that don't contain a mapping to a new entry. + if line.startswith('OBSLTE') and len(line) > 30: + # Format: Date From To + # 'OBSLTE 31-JUL-94 116L 216L' + from_id = line[20:24].lower() + to_id = line[29:33].lower() + result[from_id] = to_id + return result + + +def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]: + """Parses release dates file, returns a mapping from PDBs to release dates.""" + if path.endswith('txt'): + release_dates = {} + with open(path, 'r') as f: + for line in f: + pdb_id, date = line.split(':') + date = date.strip() + # Python 3.6 doesn't have datetime.date.fromisoformat() which is about 90x faster than strptime. + # However, splitting the string manually is about 10x faster than strptime. + release_dates[pdb_id.strip()] = \ + datetime.datetime(year=int(date[:4]), month=int(date[5:7]), day=int(date[8:10])) + return release_dates + raise ValueError('Invalid format of the release date file %s.' % path) + + +def _assess_hhsearch_hit( + hit, + hit_pdb_code, + query_sequence, + query_pdb_code, + release_dates, + release_date_cutoff, + max_subsequence_ratio=0.95, + min_align_ratio=0.1): + """Determines if template is valid (without parsing the template mmcif file). + + Args: + hit: HhrHit for the template. + hit_pdb_code: The 4 letter pdb code of the template hit. This might be + different from the value in the actual hit since the original pdb might + have become obsolete. + query_sequence: Amino acid sequence of the query. + query_pdb_code: 4 letter pdb code of the query. + release_dates: Dictionary mapping pdb codes to their structure release + dates. + release_date_cutoff: Max release date that is valid for this query. + max_subsequence_ratio: Exclude any exact matches with this much overlap. + min_align_ratio: Minimum overlap between the template and query. + + Returns: + True if the hit passed the prefilter. Raises an exception otherwise. + + Raises: + DateError: If the hit date was after the max allowed date. + PdbIdError: If the hit PDB ID was identical to the query. + AlignRatioError: If the hit align ratio to the query was too small. + DuplicateError: If the hit was an exact subsequence of the query. + LengthError: If the hit was too short. + """ + aligned_cols = hit.aligned_cols + align_ratio = aligned_cols / len(query_sequence) + + template_sequence = hit.hit_sequence.replace('-', '') + length_ratio = float(len(template_sequence)) / len(query_sequence) + + # Check whether the template is a large subsequence or duplicate of original + # query. This can happen due to duplicate entries in the PDB database. 
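+    # Illustrative numbers (not from a real run): a 98-residue template that is
+    # an exact substring of a 100-residue query gives length_ratio = 0.98,
+    # which exceeds max_subsequence_ratio = 0.95, so the hit is rejected below
+    # as a duplicate.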
+ duplicate = (template_sequence in query_sequence and length_ratio > max_subsequence_ratio) + if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff): + raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date ({release_date_cutoff}).') + + if query_pdb_code is not None: + if query_pdb_code.lower() == hit_pdb_code.lower(): + raise PdbIdError('PDB code identical to Query PDB code.') + + if align_ratio <= min_align_ratio: + raise AlignRatioError(f'Proportion of residues aligned to query too small. Align ratio: {align_ratio}.') + + if duplicate: + raise DuplicateError(f'Template is an exact subsequence of query with large coverage.' + f' Length ratio: {length_ratio}.') + + if len(template_sequence) < 10: + raise LengthError(f'Template too short. Length: {len(template_sequence)}.') + + return True + + +def _find_template_in_pdb(template_chain_id, template_sequence, mmcif_object): + """Tries to find the template chain in the given pdb file. + + This method tries the three following things in order: + 1. Tries if there is an exact match in both the chain ID and the sequence. + If yes, the chain sequence is returned. Otherwise: + 2. Tries if there is an exact match only in the sequence. + If yes, the chain sequence is returned. Otherwise: + 3. Tries if there is a fuzzy match (X = wildcard) in the sequence. + If yes, the chain sequence is returned. + If none of these succeed, a SequenceNotInTemplateError is thrown. + + Args: + template_chain_id: The template chain ID. + template_sequence: The template chain sequence. + mmcif_object: The PDB object to search for the template in. + + Returns: + A tuple with: + * The chain sequence that was found to match the template in the PDB object. + * The ID of the chain that is being returned. + * The offset where the template sequence starts in the chain sequence. + + Raises: + SequenceNotInTemplateError: If no match is found after the steps described + above. + """ + # Try if there is an exact match in both the chain ID and the + # (sub)sequence. + pdb_id = mmcif_object.file_id + chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id) + if chain_sequence and (template_sequence in chain_sequence): + logging.info('Found an exact template match %s_%s.', pdb_id, template_chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, template_chain_id, mapping_offset + + # Try if there is an exact match in the (sub)sequence only. + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + if chain_sequence and (template_sequence in chain_sequence): + logging.info(f'Found a sequence-only match {pdb_id}_{chain_id}.') + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, chain_id, mapping_offset + + # Return a chain sequence that fuzzy matches (X = wildcard) the template. + # Make parentheses unnamed groups (?:_) to avoid the 100 named groups + # limit. + regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence] + regex = re.compile(''.join(regex)) + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + match = re.search(regex, chain_sequence) + if match: + logging.info(f'Found a fuzzy sequence-only match {pdb_id}_{chain_id}.') + mapping_offset = match.start() + return chain_sequence, chain_id, mapping_offset + + # No hits, raise an error. + raise SequenceNotInTemplateError( + 'Could not find the template sequence in %s_%s. 
Template sequence: %s, ' + 'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence, mmcif_object.chain_to_seqres)) + + +def _realign_pdb_template_to_query( + old_template_sequence, + template_chain_id, + mmcif_object, + old_mapping, + kalign_binary_path): + """Aligns template from the mmcif_object to the query. + + In case PDB70 contains a different version of the template sequence, we need + to perform a realignment to the actual sequence that is in the mmCIF file. + This method performs such realignment, but returns the new sequence and + mapping only if the sequence in the mmCIF file is 90% identical to the old + sequence. + + Note that the old_template_sequence comes from the hit, and contains only that + part of the chain that matches with the query while the new_template_sequence + is the full chain. + + Args: + old_template_sequence: The template sequence that was returned by the PDB + template search (typically done using HHSearch). + template_chain_id: The template chain id was returned by the PDB template + search (typically done using HHSearch). This is used to find the right + chain in the mmcif_object chain_to_seqres mapping. + mmcif_object: A mmcif_object which holds the actual template data. + old_mapping: A mapping from the query sequence to the template sequence. + This mapping will be used to compute the new mapping from the query + sequence to the actual mmcif_object template sequence by aligning the + old_template_sequence and the actual template sequence. + kalign_binary_path: The path to a kalign executable. + + Returns: + A tuple (new_template_sequence, new_query_to_template_mapping) where: + * new_template_sequence is the actual template sequence that was found in + the mmcif_object. + * new_query_to_template_mapping is the new mapping from the query to the + actual template found in the mmcif_object. + + Raises: + QueryToTemplateAlignError: + * If there was an error thrown by the alignment tool. + * Or if the actual template sequence differs by more than 10% from the + old_template_sequence. + """ + aligner = Kalign(binary_path=kalign_binary_path) + new_template_sequence = mmcif_object.chain_to_seqres.get(template_chain_id, '') + + # Sometimes the template chain id is unknown. But if there is only a single + # sequence within the mmcif_object, it is safe to assume it is that one. + if not new_template_sequence: + if len(mmcif_object.chain_to_seqres) == 1: + logging.info(f'Could not find {template_chain_id} in {mmcif_object.file_id}, but there is only 1 sequence,' + f' so using that one.') + new_template_sequence = list(mmcif_object.chain_to_seqres.values())[0] + else: + raise QueryToTemplateAlignError( + f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. ' + 'If there are no mmCIF parsing errors, it is possible it was not a ' + 'protein chain.') + + try: + (old_aligned_template, new_aligned_template), _ = \ + parse_a3m(aligner.align([old_template_sequence, new_template_sequence])) + except Exception as e: + raise QueryToTemplateAlignError( + 'Could not align old template %s to template %s (%s_%s). 
Error: %s' % + (old_template_sequence, + new_template_sequence, + mmcif_object.file_id, + template_chain_id, + str(e))) + + logging.info(f'Old aligned template: {old_aligned_template}\nNew aligned template: {new_aligned_template}') + + old_to_new_template_mapping = {} + old_template_index = -1 + new_template_index = -1 + num_same = 0 + for old_template_aa, new_template_aa in zip(old_aligned_template, new_aligned_template): + if old_template_aa != '-': + old_template_index += 1 + if new_template_aa != '-': + new_template_index += 1 + if old_template_aa != '-' and new_template_aa != '-': + old_to_new_template_mapping[old_template_index] = new_template_index + if old_template_aa == new_template_aa: + num_same += 1 + + # Require at least 90 % sequence identity wrt to the shorter of the sequences. + if float(num_same) / min(len(old_template_sequence), len(new_template_sequence)) < 0.9: + raise QueryToTemplateAlignError( + 'Insufficient similarity of the sequence in the database: %s to the ' + 'actual sequence in the mmCIF file %s_%s: %s. We require at least ' + '90 %% similarity wrt to the shorter of the sequences. This is not a ' + 'problem unless you think this is a template that should be included.' % + (old_template_sequence, mmcif_object.file_id, template_chain_id, + new_template_sequence)) + + new_query_to_template_mapping = {} + for query_index, old_template_index in old_mapping.items(): + new_query_to_template_mapping[query_index] = (old_to_new_template_mapping.get(old_template_index, -1)) + + new_template_sequence = new_template_sequence.replace('-', '') + + return new_template_sequence, new_query_to_template_mapping + + +def _check_residue_distances(all_positions: np.ndarray, + all_positions_mask: np.ndarray, + max_ca_ca_distance: float): + """Checks if the distance between unmasked neighbor residues is ok.""" + ca_position = residue_constants.atom_order['CA'] + prev_is_unmasked = False + prev_calpha = None + for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)): + this_is_unmasked = bool(mask[ca_position]) + if this_is_unmasked: + this_calpha = coords[ca_position] + if prev_is_unmasked: + distance = np.linalg.norm(this_calpha - prev_calpha) + if distance > max_ca_ca_distance: + raise CaDistanceError('The distance between residues %d and %d is %f > limit %f.' 
% + (i, i + 1, distance, max_ca_ca_distance)) + prev_calpha = this_calpha + prev_is_unmasked = this_is_unmasked + + +def _get_atom_positions( + mmcif_object, + auth_chain_id, + max_ca_ca_distance) -> Tuple[np.ndarray, np.ndarray]: + """Gets atom positions and mask from a list of Biopython Residues.""" + num_res = len(mmcif_object.chain_to_seqres[auth_chain_id]) + + relevant_chains = [c for c in mmcif_object.structure.get_chains() if c.id == auth_chain_id] + if len(relevant_chains) != 1: + raise MultipleChainsError(f'Expected exactly one chain in structure with id {auth_chain_id}.') + chain = relevant_chains[0] + + all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3]) + all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num], dtype=np.int64) + for res_index in range(num_res): + pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32) + mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32) + res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index] + if not res_at_position.is_missing: + res = chain[(res_at_position.hetflag, + res_at_position.position.residue_number, + res_at_position.position.insertion_code)] + for atom in res.get_atoms(): + atom_name = atom.get_name() + x, y, z = atom.get_coord() + if atom_name in residue_constants.atom_order.keys(): + pos[residue_constants.atom_order[atom_name]] = [x, y, z] + mask[residue_constants.atom_order[atom_name]] = 1.0 + elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE': + # Put the coordinates of the selenium atom in the sulphur + # column. + pos[residue_constants.atom_order['SD']] = [x, y, z] + mask[residue_constants.atom_order['SD']] = 1.0 + + all_positions[res_index] = pos + all_positions_mask[res_index] = mask + _check_residue_distances(all_positions, all_positions_mask, max_ca_ca_distance) + return all_positions, all_positions_mask + + +def _extract_template_features( + mmcif_object, + pdb_id, + mapping, + template_sequence, + query_sequence, + template_chain_id, + confidence_scores, + kalign_binary_path): + """Parses atom positions in the target structure and aligns with the query. + + Atoms for each residue in the template structure are indexed to coincide + with their corresponding residue in the query sequence, according to the + alignment mapping provided. + + Note that we only extract at most 500 templates because of HHSearch settings. + + We set missing/invalid confidence scores to the default value of -1. + Note: We now have 4 types of confidence scores: + 1. Valid scores + 2. Invalid scores of residues not in both the query sequence and template + sequence + 3. Missing scores because we don't have the secondary structure, and HHAlign + doesn't produce the posterior probabilities in this case. + 4. Missing scores because of a different template sequence in PDB70, + invalidating the previously computed confidence scores. (Though in theory + HHAlign can be run on these to recompute the correct confidence scores). + We handle invalid and missing scores by setting them to -1, but consider + adding masks for the different types. + + Args: + mmcif_object: mmcif_parsing.MmcifObject representing the template. + pdb_id: PDB code for the template. + mapping: Dictionary mapping indices in the query sequence to indices in + the template sequence. + template_sequence: String describing the amino acid sequence for the + template protein. + query_sequence: String describing the amino acid sequence for the query + protein. 
+ template_chain_id: String ID describing which chain in the structure proto + should be used. + confidence_scores: String containing per-residue confidence scores, where + each character represents the *TRUNCATED* posterior probability that the + corresponding template residue is correctly aligned with the query + residue, given the database match is correct (0 corresponds approximately + to 0-10%, 9 to 90-100%). + kalign_binary_path: The path to a kalign executable used for template + realignment. + + Returns: + A tuple with: + * A dictionary containing the extra features derived from the template + protein structure. + * A warning message if the hit was realigned to the actual mmCIF sequence. + Otherwise None. + + Raises: + NoChainsError: If the mmcif object doesn't contain any chains. + SequenceNotInTemplateError: If the given chain id / sequence can't + be found in the mmcif object. + QueryToTemplateAlignError: If the actual template in the mmCIF file + can't be aligned to the query. + NoAtomDataInTemplateError: If the mmcif object doesn't contain + atom positions. + TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any + unmasked residues. + """ + if mmcif_object is None or not mmcif_object.chain_to_seqres: + raise NoChainsError('No chains in PDB: %s_%s' % (pdb_id, template_chain_id)) + + warning = None + try: + seqres, chain_id, mapping_offset = _find_template_in_pdb( + template_chain_id=template_chain_id, + template_sequence=template_sequence, + mmcif_object=mmcif_object) + except SequenceNotInTemplateError: + # If PDB70 contains a different version of the template, we use the sequence + # from the mmcif_object. + chain_id = template_chain_id + warning = (f'The exact sequence {template_sequence} was not found in ' + f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.') + logging.warning(warning) + # This throws an exception if it fails to realign the hit. + seqres, mapping = _realign_pdb_template_to_query( + old_template_sequence=template_sequence, + template_chain_id=template_chain_id, + mmcif_object=mmcif_object, + old_mapping=mapping, + kalign_binary_path=kalign_binary_path) + logging.info(f'Sequence in {pdb_id}_{chain_id}: {template_sequence} successfully realigned to {seqres}') + # The template sequence changed. + template_sequence = seqres + # No mapping offset, the query is aligned to the actual sequence. + mapping_offset = 0 + # Confidence scores were based on the previous sequence, so they are + # invalid + confidence_scores = None + + try: + # Essentially set to infinity - we don't want to reject templates unless + # they're really really bad. 
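+        # (Consecutive CA-CA distances in real protein chains are ~3.8 Angstrom,
+        # so a 150.0 Angstrom limit only rejects grossly corrupted coordinates.)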
+ all_atom_positions, all_atom_mask = _get_atom_positions(mmcif_object, chain_id, max_ca_ca_distance=150.0) + except (CaDistanceError, KeyError) as ex: + raise NoAtomDataInTemplateError(f'Could not get atom data ({pdb_id}_{chain_id}): {str(ex)}') + + all_atom_positions = np.split(all_atom_positions, all_atom_positions.shape[0]) + all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0]) + + output_templates_sequence = [] + output_confidence_scores = [] + templates_all_atom_positions = [] + templates_all_atom_masks = [] + + for _ in query_sequence: + # Residues in the query_sequence that are not in the template_sequence: + templates_all_atom_positions.append(np.zeros((residue_constants.atom_type_num, 3))) + templates_all_atom_masks.append(np.zeros(residue_constants.atom_type_num)) + output_templates_sequence.append('-') + output_confidence_scores.append(-1) + + for k, v in mapping.items(): + template_index = v + mapping_offset + templates_all_atom_positions[k] = all_atom_positions[template_index][0] + templates_all_atom_masks[k] = all_atom_masks[template_index][0] + output_templates_sequence[k] = template_sequence[v] + if confidence_scores and confidence_scores[v] != ' ': + output_confidence_scores[k] = int(confidence_scores[v]) + + # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, + # O). + if np.sum(templates_all_atom_masks) < 5: + raise TemplateAtomMaskAllZerosError('Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' % + (pdb_id, chain_id, min(mapping.values()) + mapping_offset, + max(mapping.values()) + mapping_offset)) + + output_templates_sequence = ''.join(output_templates_sequence) + + templates_aatype = residue_constants.sequence_to_onehot( + output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID) + + return ( + {'template_all_atom_positions': np.array(templates_all_atom_positions), + 'template_all_atom_masks': np.array(templates_all_atom_masks), + 'template_sequence': output_templates_sequence.encode(), + 'template_aatype': np.array(templates_aatype), + 'template_confidence_scores': np.array(output_confidence_scores), + 'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(), + 'template_release_date': mmcif_object.header['release_date'].encode()}, + warning) + + +def _build_query_to_hit_index_mapping( + hit_query_sequence: str, + hit_sequence: str, + indices_hit: Sequence[int], + indices_query: Sequence[int], + original_query_sequence: str) -> Mapping[int, int]: + """Gets mapping from indices in original query sequence to indices in the hit. + + hit_query_sequence and hit_sequence are two aligned sequences containing gap + characters. hit_query_sequence contains only the part of the original query + sequence that matched the hit. When interpreting the indices from the .hhr, we + need to correct for this to recover a mapping from original query sequence to + the hit sequence. + + Args: + hit_query_sequence: The portion of the query sequence that is in the .hhr + hit + hit_sequence: The portion of the hit sequence that is in the .hhr + indices_hit: The indices for each aminoacid relative to the hit sequence + indices_query: The indices for each aminoacid relative to the original query + sequence + original_query_sequence: String describing the original query sequence. + + Returns: + Dictionary with indices in the original query sequence as keys and indices + in the hit sequence as values. 
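+
+    Example (illustrative values, not taken from a real .hhr file):
+        original_query_sequence = 'AAKLVG'
+        hit_query_sequence = 'KLV'          # matches the query at offset 2
+        hit_sequence = 'KIV'
+        indices_query = [2, 3, 4]
+        indices_hit = [5, 6, 7]
+        returns {2: 0, 3: 1, 4: 2}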
+ """ + # If the hit is empty (no aligned residues), return empty mapping + if not hit_query_sequence: + return {} + + # Remove gaps and find the offset of hit.query relative to original query. + hhsearch_query_sequence = hit_query_sequence.replace('-', '') + hit_sequence = hit_sequence.replace('-', '') + hhsearch_query_offset = original_query_sequence.find(hhsearch_query_sequence) + + # Index of -1 used for gap characters. Subtract the min index ignoring + # gaps. + min_idx = min(x for x in indices_hit if x > -1) + fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit] + + min_idx = min(x for x in indices_query if x > -1) + fixed_indices_query = [x - min_idx if x > - 1 else - 1 for x in indices_query] + + # Zip the corrected indices, ignore case where both seqs have gap + # characters. + mapping = {} + for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit): + if q_t != -1 and q_i != -1: + if (q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(original_query_sequence)): + continue + mapping[q_i + hhsearch_query_offset] = q_t + + return mapping + + +@dataclasses.dataclass(frozen=True) +class SingleHitResult: + features: Optional[Mapping[str, Any]] + error: Optional[str] + warning: Optional[str] + + +def _process_single_hit( + query_sequence, + query_pdb_code, + hit, + mmcif_dir, + max_template_date, + release_dates, + obsolete_pdbs, + kalign_binary_path, + strict_error_check): + """Tries to extract template features from a single HHSearch hit.""" + # Fail hard if we can't get the PDB ID and chain name from the hit. + hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) + + if hit_pdb_code not in release_dates: + if hit_pdb_code in obsolete_pdbs: + hit_pdb_code = obsolete_pdbs[hit_pdb_code] + + # Pass hit_pdb_code since it might have changed due to the pdb being + # obsolete. + try: + _assess_hhsearch_hit( + hit=hit, + hit_pdb_code=hit_pdb_code, + query_sequence=query_sequence, + query_pdb_code=query_pdb_code, + release_dates=release_dates, + release_date_cutoff=max_template_date) + except PrefilterError as e: + msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}' + logging.info('%s: %s', query_pdb_code, msg) + if strict_error_check and isinstance(e, (DateError, PdbIdError, DuplicateError)): + # In strict mode we treat some prefilter cases as errors. + return SingleHitResult(features=None, error=msg, warning=None) + + return SingleHitResult(features=None, error=None, warning=None) + + mapping = _build_query_to_hit_index_mapping( + hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query, query_sequence) + + # The mapping is from the query to the actual hit sequence, so we need to + # remove gaps (which regardless have a missing confidence score). + template_sequence = hit.hit_sequence.replace('-', '') + confidence_scores = ''.join([cs for t, cs in zip(hit.hit_sequence, hit.confidence_scores) if t != '-']) + + cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif') + if not os.path.exists(cif_path): + cif_path = os.path.join(mmcif_dir, hit_pdb_code.upper() + '.cif') + logging.info('Reading PDB entry from %s. Query: %s, template: %s', cif_path, query_sequence, template_sequence) + # Fail if we can't find the mmCIF file. 
+ with open(cif_path, 'r') as cif_file: + cif_string = cif_file.read() + + parsing_result = parse_mmcif(file_id=hit_pdb_code, mmcif_string=cif_string) + + if parsing_result.mmcif_object is not None: + hit_release_date = datetime.datetime.strptime(parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d') + if hit_release_date > max_template_date: + error = ('Template %s date (%s) > max template date (%s).' % + (hit_pdb_code, hit_release_date, max_template_date)) + if strict_error_check: + return SingleHitResult(features=None, error=error, warning=None) + logging.warning(error) + return SingleHitResult(features=None, error=None, warning=None) + + try: + features, realign_warning = _extract_template_features( + mmcif_object=parsing_result.mmcif_object, + pdb_id=hit_pdb_code, + mapping=mapping, + template_sequence=template_sequence, + query_sequence=query_sequence, + template_chain_id=hit_chain_id, + confidence_scores=confidence_scores, + kalign_binary_path=kalign_binary_path) + features['template_e_value'] = [hit.e_value] + features['template_sum_probs'] = [hit.sum_probs] + features['template_prob_true'] = [hit.prob_true] + features['template_score'] = [hit.score] + features['template_neff'] = [hit.neff] + features['template_similarity'] = [hit.similarity] + + # It is possible there were some errors when parsing the other chains in the + # mmCIF file, but the template features for the chain we want were still + # computed. In such case the mmCIF parsing errors are not relevant. + return SingleHitResult(features=features, error=None, warning=realign_warning) + except (NoChainsError, NoAtomDataInTemplateError, + TemplateAtomMaskAllZerosError) as e: + # These 3 errors indicate missing mmCIF experimental data rather than a + # problem with the template search, so turn them into warnings. + warning = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % (hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors)) + if strict_error_check: + return SingleHitResult(features=None, error=warning, warning=None) + return SingleHitResult(features=None, error=None, warning=warning) + except Error as e: + error = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % (hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors)) + return SingleHitResult(features=None, error=error, warning=None) + + +@dataclasses.dataclass(frozen=True) +class TemplateSearchResult: + features: Mapping[str, Any] + errors: Sequence[str] + warnings: Sequence[str] + + +class TemplateHitFeaturizer: + """A class for turning hhr hits to template features.""" + + def __init__( + self, + mmcif_dir: str, + max_template_date: str, + max_hits: int, + kalign_binary_path: str, + release_dates_path: Optional[str], + obsolete_pdbs_path: Optional[str], + strict_error_check: bool = False): + """Initializes the Template Search. + + Args: + mmcif_dir: Path to a directory with mmCIF structures. Once a template ID + is found by HHSearch, this directory is used to retrieve the template + data. + max_template_date: The maximum date permitted for template structures. No + template with date higher than this date will be returned. In ISO8601 + date format, YYYY-MM-DD. + max_hits: The maximum number of templates that will be returned. + kalign_binary_path: The path to a kalign executable used for template + realignment. 
+ release_dates_path: An optional path to a file with a mapping from PDB IDs + to their release dates. Thanks to this we don't have to redundantly + parse mmCIF files to get that information. + obsolete_pdbs_path: An optional path to a file containing a mapping from + obsolete PDB IDs to the PDB IDs of their replacements. + strict_error_check: If True, then the following will be treated as errors: + * If any template date is after the max_template_date. + * If any template has identical PDB ID to the query. + * If any template is a duplicate of the query. + * Any feature computation errors. + """ + self._mmcif_dir = mmcif_dir + if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')): + logging.error('Could not find CIFs in %s', self._mmcif_dir) + raise ValueError(f'Could not find CIFs in {self._mmcif_dir}') + + try: + self._max_template_date = datetime.datetime.strptime(max_template_date, '%Y-%m-%d') + except ValueError: + raise ValueError('max_template_date must be set and have format YYYY-MM-DD.') + self._max_hits = max_hits + self._kalign_binary_path = kalign_binary_path + self._strict_error_check = strict_error_check + + if release_dates_path: + logging.info('Using precomputed release dates %s.', release_dates_path) + self._release_dates = _parse_release_dates(release_dates_path) + else: + self._release_dates = {} + + if obsolete_pdbs_path: + logging.info('Using precomputed obsolete pdbs %s.', obsolete_pdbs_path) + self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path) + else: + self._obsolete_pdbs = {} + + def get_templates( + self, + query_sequence, + query_pdb_code, + query_release_date, + hhr_hits): + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_pdb_code) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + # Always use a max_template_date. Set to query_release_date minus 60 days + # if that's earlier. + template_cutoff_date = self._max_template_date + if query_release_date: + delta = datetime.timedelta(days=60) + if query_release_date - delta < template_cutoff_date: + template_cutoff_date = query_release_date - delta + assert template_cutoff_date < query_release_date + assert template_cutoff_date <= self._max_template_date + + num_hits = 0 + errors = [] + warnings = [] + + for hit in sorted(hhr_hits, key=lambda x: x.sum_probs, reverse=True): + # We got all the templates we wanted, stop processing HHSearch + # hits. + if num_hits >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + query_pdb_code=query_pdb_code, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=template_cutoff_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.info('Skipped invalid hit %s, error: %s, warning: %s', hit.name, result.error, result.warning) + else: + # Increment the hit counter, since we got features out of this + # hit. 
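+                # Hits are visited in descending sum_probs order, so the first
+                # max_hits hits that yield features are the ones kept.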
+                num_hits += 1
+                for k in template_features:
+                    template_features.get(k).append(result.features[k])
+
+        for name in template_features:
+            if num_hits > 0:
+                template_features[name] = np.stack(template_features.get(name),
+                                                   axis=0).astype(TEMPLATE_FEATURES.get(name))
+            else:
+                # Make sure the feature has correct dtype even if empty.
+                template_features[name] = np.array([], dtype=TEMPLATE_FEATURES.get(name))
+
+        return TemplateSearchResult(features=template_features, errors=errors, warnings=warnings)
diff --git a/MindSPONGE/applications/research/Grasp/data/utils.py b/MindSPONGE/applications/research/Grasp/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5f4ce9d19b37ce0164f9408c1a3cd17d374f43e
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/data/utils.py
@@ -0,0 +1,188 @@
+# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+utils module used for tmpdir generation.
+"""
+import time
+import contextlib
+import tempfile
+import shutil
+import pickle
+import os
+import gzip
+import numpy as np
+from absl import logging
+from scipy import sparse as sp
+
+from .parsers import parse_fasta
+
+truncated_normal_stddev_factor = np.asarray(.87962566103423978, dtype=np.float32)
+
+
+@contextlib.contextmanager
+def tmpdir_manager(base_dir: str):
+    """Context manager that deletes a temporary directory on exit.
+
+    Example:
+        with tmpdir_manager(base_dir='/tmp') as tmp_dir:
+            test_file = os.path.join(tmp_dir, 'input.fasta')
+            with open(test_file, "w") as f:
+                f.write("this is a test. \n")
+            print("exit")
+
+    This creates a temporary data directory; once the body of the `with` block
+    has finished writing "this is a test. \n" into the temporary file (i.e.
+    after "exit" is printed), the temporary directory is deleted.
+    """
+    tmpdir = tempfile.mkdtemp(dir=base_dir)
+    try:
+        yield tmpdir
+    finally:
+        shutil.rmtree(tmpdir, ignore_errors=True)
+
+
+@contextlib.contextmanager
+def timing(msg: str):
+    logging.info('Started %s', msg)
+    tic = time.time()
+    yield
+    toc = time.time()
+    logging.info('Finished %s in %.3f seconds', msg, toc - tic)
+
+
+def get_raw_feature(input_path, feature_generator, use_pkl):
+    '''Get the raw feature of a protein by loading a pkl file or searching the databases.'''
+    if use_pkl:
+        with open(input_path, "rb") as f:
+            data = pickle.load(f)
+        return data
+    return feature_generator.monomer_feature_generate(input_path)
+
+
+def get_crop_size(input_path, use_pkl):
+    '''Get the crop size by comparing the lengths of all input sequences.'''
+    filenames = os.listdir(input_path)
+    max_length = 0
+    for filename in filenames:
+        file_full_path = os.path.join(input_path, filename)
+        if use_pkl:
+            with open(file_full_path, "rb") as f:
+                data = pickle.load(f)
+            current_crop_size = (data["msa"].shape[1] // 256 + 1) * 256
+            max_length = max(max_length, current_crop_size)
+        else:
+            with open(file_full_path, "r") as f:
+                input_fasta_str = f.read()
+            input_seqs, _ = parse_fasta(input_fasta_str)
+            current_crop_size = (len(input_seqs[0]) // 256 + 1) * 256
+            max_length = max(max_length, current_crop_size)
+
+    return max_length
+
+
+# def load_pickle(path):
+#     def load(path):
+#         assert path.endswith(".pkl") or path.endswith(
+#             ".pkl.gz"
+#         ), f"bad suffix in {path} as pickle file."
+#         open_fn = gzip.open if path.endswith(".gz") else open
+#         with open_fn(path, "rb") as f:
+#             return pickle.load(f)
+
+#     ret = load(path)
+#     ret = uncompress_features(ret)
+#     return ret
+
+
+# def uncompress_features(feats):
+#     if "sparse_deletion_matrix_int" in feats:
+#         v = feats.pop("sparse_deletion_matrix_int")
+#         v = to_dense_matrix(v)
+#         feats["deletion_matrix"] = v
+#     return feats
+
+
+# def to_dense_matrix(spmat_dict):
+#     spmat = sp.coo_matrix(
+#         (spmat_dict["data"], (spmat_dict["row"], spmat_dict["col"])),
+#         shape=spmat_dict["shape"],
+#         dtype=np.float32,
+#     )
+#     return spmat.toarray()
+
+
+def str_hash(text):
+    hash_value = 0
+    for ch in text:
+        hash_value = (hash_value * 281 ^ ord(ch) * 997) & 0xFFFFFFFF
+    return hash_value
+
+
+@contextlib.contextmanager
+def numpy_seed(seed, *addl_seeds, key=None):
+    """Context manager which seeds the NumPy PRNG with the specified seed and
+    restores the state afterward"""
+    if seed is None:
+        yield
+        return
+
+    def check_seed(s):
+        assert type(s) in (int, np.int32, np.int64)
+
+    check_seed(seed)
+    if len(addl_seeds) > 0:
+        for s in addl_seeds:
+            check_seed(s)
+        seed = int(hash((seed, *addl_seeds)) % 1e8)
+    if key is not None:
+        seed = int(hash((seed, str_hash(key))) % 1e8)
+    state = np.random.get_state()
+    np.random.seed(seed)
+    try:
+        yield
+    finally:
+        np.random.set_state(state)
+
+
+# def batch_by_size(
+#     indices,
+#     batch_size=None,
+#     required_batch_size_multiple=1,
+# ):
+#     """
+#     Yield mini-batches of indices bucketed by size. Batches may contain
+#     sequences of different lengths.
+
+#     Args:
+#         indices (List[int]): ordered list of dataset indices
+#         batch_size (int, optional): max number of sentences in each
+#             batch (default: None).
+#         required_batch_size_multiple (int, optional): require batch size to
+#             be less than N or a multiple of N (default: 1).
+# """ + +# batch_size = batch_size if batch_size is not None else 1 +# bsz_mult = required_batch_size_multiple + +# step = ((batch_size + bsz_mult - 1) // bsz_mult) * bsz_mult + +# if not isinstance(indices, np.ndarray): +# indices = np.fromiter(indices, dtype=np.int64, count=-1) + +# num_batches = (len(indices) + step - 1) // step +# steps = np.arange(num_batches - 1) + 1 +# steps *= step +# batch_indices = np.split(indices, steps) +# assert len(batch_indices) == num_batches +# # validation or test data size is smaller than a mini-batch size in some downstream tasks. +# assert batch_indices[0].shape[0] <= step +# return batch_indices diff --git a/MindSPONGE/applications/research/Grasp/infer_main.py b/MindSPONGE/applications/research/Grasp/infer_main.py new file mode 100644 index 0000000000000000000000000000000000000000..14eb0bd42a771671193c7a25f6d77b67fdf38d44 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/infer_main.py @@ -0,0 +1,88 @@ +import argparse +import os +import glob +import time +import datetime +import pickle +import pandas as pd +import numpy as np +from restraint_sample import BINS + +import mindspore +from mindspore import context +import mindspore.communication as D +from mindspore import Tensor, ops + + +parser = argparse.ArgumentParser(description='Inputs for eval.py') +parser.add_argument('--raw_feat', default='./grasp2/features.pkl', help='Location of raw features pickle input') #/job/dataset/csp/raw_feat/5JDS.pkl './examples_7STZ_2152/features.pkl' './5JDS.pkl' './6HTX.pkl' 'T0001_features.pkl' ./grasp/features.pkl +parser.add_argument('--output_dir', default='./compare_with_parallel', help='Output directory for predictions') #/job/output/test +parser.add_argument('--restr', default="./grasp2/restr_5perc.pkl", help='Location of restraints pickle input, if not provided, will infer without restraints') # ./grasp2/restr_5perc.pkl +parser.add_argument('--ckpt_path', default="./step_14000.ckpt", help='ckpt path')#/job/output/ckpt_dir/ft-grasp-v11-64/step_8000.ckpt params_model_1_multimer_v3_ms.ckpt +parser.add_argument('--data_config', default="./config/data-infer.yaml", help='data process config') # ./config/data-infer.yaml +parser.add_argument('--model_config', default="./config/model-infer.yaml", help='model config') # ./config/model-infer.yaml +parser.add_argument('--seq_len', default=8192, type=int) # sequence will be padded to this length 256 +parser.add_argument('--mixed_precision', default=1, type=int) +parser.add_argument('--multimer', default=1, type=int) +parser.add_argument('--device_num', default=8, type=int) +parser.add_argument('--iter', default=5, type=int) +parser.add_argument('--num_recycle', default=20, type=int) + + + +arguments = parser.parse_args() +# context.set_context(device_target="Ascend", device_id=6) +# context.set_context(device_target="Ascend", device_id=7, mode=mindspore.GRAPH_MODE, save_graphs=1, save_graphs_path='./compare_with_parallel/single_graphs/') #, save_graphs=1, save_graphs_path='./compare_with_parallel/single_graphs/' +# from utils_infer_single import infer_config, infer_batch, DataGenerator, ModelGenerator, grasp_infer + +context.set_context(device_target="Ascend", + mode=mindspore.GRAPH_MODE, + max_call_depth=24000, + max_device_memory='58GB', + # save_graphs=True, + # save_graphs_path='./compare_with_parallel/graphs/' + # save_graphs=True + # memory_optimize_level="O1", + # jit_syntax_level=0 + # variable_memory_max_size="30GB" + # save_graphs=1, save_graphs_path='./compare_with_parallel/graphs_25/' + )#, 
save_graphs=1, save_graphs_path='./compare_with_parallel/graphs/', save_graphs=1, save_graphs_path='./compare_with_parallel/graphs_24/', jit_config={"jit_level": "O0"} , memory_optimize_level="O1", jit_syntax_level=0 +split_rank = arguments.device_num +data_strategy=((split_rank,),(split_rank,),(1,split_rank),(1,split_rank,1),(1,split_rank,1,1), + (split_rank,),(split_rank,),(split_rank,),(split_rank,),(1,split_rank), + (split_rank,1),(1, split_rank, 1),(1,split_rank),(1,split_rank),(1,split_rank), + (split_rank,1),(split_rank,1),(split_rank,1,1), (split_rank, 1), (split_rank,) ,(split_rank,1,1),(split_rank,1),(1,split_rank,1)) +# data_strategy=((1, split_rank, 1),) +# data_strategy=((split_rank,1,1),(split_rank,1,1), (1, split_rank), (split_rank, 1, 1), (split_rank, 1)) +mindspore.set_auto_parallel_context(device_num=split_rank, parallel_mode=mindspore.ParallelMode.SEMI_AUTO_PARALLEL, dataset_strategy=data_strategy, enable_alltoall=False) # 数据集按数据并行的方式切分,且shard的输出张量也按数据并行方式切分, search_mode="sharding_propagation", +D.init() +from utils_infer import infer_config, infer_batch, DataGenerator, ModelGenerator, grasp_infer + + +# print(arguments) +model_gen = ModelGenerator(arguments, arguments.ckpt_path) + +with open(arguments.raw_feat, 'rb') as f: + raw_feature = pickle.load(f) + +restr = None +if arguments.restr != "None": + with open(arguments.restr, 'rb') as f: + restr = pickle.load(f) + +print("debug raw_feat keys", raw_feature.keys()) +t1 = time.time() +grasp_infer(model_gen=model_gen, + ckpt_id=8000, + raw_feature=raw_feature, + restraints=restr, + output_prefix=f'{arguments.output_dir}/test6_{arguments.seq_len}', + iter=arguments.iter, + seed=9, + num_recycle=arguments.num_recycle, + device_num=arguments.device_num + ) + +t2 = time.time() +print("Inference done!") +print("time cost: ", t2 - t1) \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/infer_main_parallel.sh b/MindSPONGE/applications/research/Grasp/infer_main_parallel.sh new file mode 100644 index 0000000000000000000000000000000000000000..d72884471e007eb44c2fe2a3bc3277ced797484f --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/infer_main_parallel.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +input="$1" + +count=$(echo "$input" | tr ',' '\n' | grep -c '[0-9]') + +IFS=';' read -r -a input5 <<< $3 + +raw_feat=${input5[0]} +restr=${input5[1]:-None} +ckpt_path=${input5[2]} +iter=${input5[3]} +num_recycle=${input5[4]} + +export MS_ASCEND_CHECK_OVERFLOW_MODE=SATURATION_MODE +# export MS_MEMORY_STATISTIC=1 +# export MS_KERNEL_LAUNCH_SKIP=all +export ASCEND_RT_VISIBLE_DEVICES=$input +export HCCL_CONNECT_TIMEOUT=6000 +# export MS_ALLOC_CONF="memory_tracker:True" +# export MS_DEV_DUMP_IR_PASSES="hwopt_d_after_stream_assign,valid,graph_build" + +# export GLOG_v=2 +# export MS_DEV_DUMP_IR_PASSES="step_parallel,validate,stream" +# export GRAPH_OP_RUN=1 +#export MS_DEV_DDE_ONLY_MARK=1 +# export MINDSPORE_DUMP_CONFIG=/autotest/protein/mindscience/MindSPONGE/applications/MEGAProtein/dump_af.json + +ulimit -u unlimited +ulimit -s 102400 +ulimit -SHn 65535 +mpirun -n $count --output-filename ./log_distribute2 --merge-stderr-to-stdout --allow-run-as-root python infer_main.py --seq_len $2 --raw_feat $raw_feat --restr $restr --ckpt_path $ckpt_path --iter $iter --num_recycle $num_recycle --device_num $count > ./log_distribute2/test_distribute_log 2>&1 \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/__init__.py new 
file mode 100644 index 0000000000000000000000000000000000000000..1f0e019414ba83e33b2d6d0f58fc82d0609928b4 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MindSPONGE""" diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/callback/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/callback/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79b064f61817f0fad2e870dc88a9dd4996d32d80 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/callback/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""callback""" + +from .h5md import WriteH5MD +from .information import RunInfo + +__all__ = ['WriteH5MD', 'RunInfo'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/callback/h5md.py b/MindSPONGE/applications/research/Grasp/mindsponge1/callback/h5md.py new file mode 100644 index 0000000000000000000000000000000000000000..95e7dcfbcaef8d08d726ee4fea9c343d9a617409 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/callback/h5md.py @@ -0,0 +1,261 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Callback to write H5MD trajectory file
+"""
+
+from mindspore.train.callback import Callback, RunContext
+
+from ..system import Molecule
+from ..optimizer import Updater
+from ..data.export import H5MD
+
+
+class WriteH5MD(Callback):
+    r"""
+    Callback to write HDF5 molecular data (H5MD) file.
+
+    Args:
+        system (Molecule): Simulation system.
+        filename (str): Name of the output H5MD file.
+        save_freq (int): Frequency (in steps) at which frames are saved. Default: 1
+        directory (str): Directory of the output file. Default: None
+        write_velocity (bool): Whether to write the velocity of the system to the H5MD file.
+            Default: False
+        write_force (bool): Whether to write the force of the system to the H5MD file.
+            Default: False
+        write_image (bool): Whether to write the image of the position of the system to the H5MD file.
+            Default: True
+        length_unit (str): Length unit for coordinates. Default: None
+        energy_unit (str): Energy unit. Default: None
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self,
+                 system: Molecule,
+                 filename: str,
+                 save_freq: int = 1,
+                 directory: str = None,
+                 write_velocity: bool = False,
+                 write_force: bool = False,
+                 write_image: bool = True,
+                 length_unit: str = None,
+                 energy_unit: str = None,
+                 ):
+
+        self.system = system
+        self.units = system.units
+        self.h5md = H5MD(self.system, filename, directory,
+                         length_unit, energy_unit)
+
+        self.use_pbc = system.pbc_box is not None
+        self.const_volume = True
+
+        self.write_image = write_image
+        if self.use_pbc and self.write_image:
+            self.h5md.set_image()
+
+        self.save_freq = save_freq
+
+        self.write_velocity = write_velocity
+        if self.write_velocity:
+            self.h5md.set_velocity()
+
+        self.write_force = write_force
+        if self.write_force:
+            self.h5md.set_force()
+
+        self.potential = 0
+        self.kinetics = 0
+        self.tot_energy = 0
+        self.temperature = 0
+        self.pressure = 0
+        self.volume = 0
+
+        self.observables = [
+            'potential_energy',
+            'kinetic_energy',
+            'total_energy',
+            'temperature',
+            'pressure',
+            'volume',
+        ]
+
+        self.obs_units = [
+            self.units.energy_unit_name,
+            self.units.energy_unit_name,
+            self.units.energy_unit_name,
+            'K',
+            'bar',
+            self.units.volume_unit_name,
+        ]
+
+        # One dtype and one scalar shape per observable.
+        self.obs_dtypes = ['float32'] * 6
+        self.obs_shapes = [()] * 6
+
+        self.h5md.set_observables(
+            self.observables, self.obs_shapes, self.obs_dtypes, self.obs_units)
+
+        self.use_updater = None
+
+        self.count = 0
+
+    def __enter__(self):
+        """Return the enter target."""
+        return self
+
+    def __exit__(self, *err):
+        """Release resources here if have any."""
+
+    def begin(self, run_context: RunContext):
+        """
+        Called once before the network executing.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
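+
+        Note: on this call the callback also determines whether the optimizer
+        is an Updater (so velocities and kinetic energy are available) and,
+        under PBC, whether the box volume is constant (no barostat present).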
+ """ + + cb_params = run_context.original_args() + if isinstance(cb_params.optimizer, Updater): + self.use_updater = True + if self.use_pbc: + self.const_volume = cb_params.barostat is None + self.h5md.set_box(self.const_volume) + else: + self.use_updater = False + if self.use_pbc: + self.h5md.set_box(True) + if self.write_velocity and not isinstance(cb_params.optimizer, Updater): + self.write_velocity = False + print('Warning! The optimizer "'+str(cb_params.optimizer) + + '" does not has the attribute "velocity".') + + def epoch_begin(self, run_context: RunContext): + """ + Called before each epoch beginning. + + Args: + run_context (RunContext): Include some information of the model. + """ + + def epoch_end(self, run_context: RunContext): + """ + Called after each epoch finished. + + Args: + run_context (RunContext): Include some information of the model. + """ + + def step_begin(self, run_context: RunContext): + """ + Called before each step beginning. + + Args: + run_context (RunContext): Include some information of the model. + """ + + if self.count % self.save_freq == 0: + cb_params = run_context.original_args() + if self.use_updater: + self.kinetics = cb_params.kinetics.asnumpy().squeeze() + self.temperature = cb_params.temperature.asnumpy().squeeze() + cb_params = run_context.original_args() + step = cb_params.cur_step + time = cb_params.cur_time + coordinate = cb_params.coordinate.asnumpy() + self.h5md.write_position(step, time, coordinate) + + if self.use_pbc: + if not self.const_volume: + pbc_box = cb_params.pbc_box.asnumpy() + self.h5md.write_box(step, time, pbc_box) + if self.write_image: + image = self.system.calc_image().asnumpy() + self.h5md.write_image(step, time, image) + + def step_end(self, run_context: RunContext): + """ + Called after each step finished. + + Args: + run_context (RunContext): Include some information of the model. + """ + + if self.count % self.save_freq == 0: + cb_params = run_context.original_args() + step = cb_params.cur_step + time = cb_params.cur_time + + self.potential = cb_params.energy.asnumpy().squeeze() + if self.use_updater: + self.tot_energy = self.potential + self.kinetics + if self.use_pbc: + self.pressure = cb_params.pressure.asnumpy().squeeze() + self.volume = cb_params.volume.asnumpy().squeeze() + + obs_values = [ + self.potential, + self.kinetics, + self.tot_energy, + self.temperature, + self.pressure, + self.volume, + ] + + self.h5md.write_observables(self.observables, step, time, obs_values) + + if self.write_velocity: + velocity = cb_params.velocity[0].asnumpy() + self.h5md.write_velocity(step, time, velocity) + if self.write_force: + force = cb_params.force.asnumpy() + self.h5md.write_force(step, time, force) + + self.count += 1 + + def end(self, run_context: RunContext): + """ + Called once after network training. + + Args: + run_context (RunContext): Include some information of the model. 
+ """ + #pylint: disable=unused-argument + self.h5md.close() diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/callback/information.py b/MindSPONGE/applications/research/Grasp/mindsponge1/callback/information.py new file mode 100644 index 0000000000000000000000000000000000000000..a60a2bd5fa1943d33ec50a30239d6d7687952a3f --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/callback/information.py @@ -0,0 +1,152 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Callback to print the information of MD simulation +""" + +from mindspore.train.callback import Callback, RunContext + +from ..optimizer import Updater + + +class RunInfo(Callback): + r""" + Callback to print the information of MD simulation. + + Args: + print_freq (int): Frequency to print out the information. Default: 1. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, print_freq: int = 1): + super().__init__() + + self.print_freq = print_freq + + self.potential = None + self.kinetics = None + self.temperature = None + self.pressure = None + self.tot_energy = None + self.volume = None + + self.use_pbc = False + self.use_updater = False + + self.crd = None + + self.count = 0 + + def __enter__(self): + """Return the enter target.""" + return self + + def __exit__(self, *err): + """Release resources here if have any.""" + + def begin(self, run_context: RunContext): + """ + Called once before the network executing. + + Args: + run_context (RunContext): Include some information of the model. + """ + cb_params = run_context.original_args() + self.use_pbc = cb_params.pbc_box is not None + if isinstance(cb_params.optimizer, Updater): + self.use_updater = True + self.kinetics = cb_params.kinetics.asnumpy().squeeze() + self.temperature = cb_params.temperature.asnumpy().squeeze() + if self.use_pbc: + self.volume = cb_params.volume.asnumpy().squeeze() + self.pressure = cb_params.pressure.asnumpy().squeeze() + + def epoch_begin(self, run_context: RunContext): + """ + Called before each epoch beginning. + + Args: + run_context (RunContext): Include some information of the model. + """ + + def epoch_end(self, run_context: RunContext): + """ + Called after each epoch finished. + + Args: + run_context (RunContext): Include some information of the model. + """ + + def step_begin(self, run_context: RunContext): + """ + Called before each step beginning. + + Args: + run_context (RunContext): Include some information of the model. 
+ """ + if self.count % self.print_freq == 0: + cb_params = run_context.original_args() + self.crd = cb_params.coordinate[0].asnumpy().squeeze() + if self.use_updater: + self.kinetics = cb_params.kinetics.asnumpy().squeeze() + self.temperature = cb_params.temperature.asnumpy().squeeze() + + def step_end(self, run_context: RunContext): + """ + Called after each step finished. + + Args: + run_context (RunContext): Include some information of the model. + """ + + if self.count % self.print_freq == 0: + cb_params = run_context.original_args() + step = cb_params.cur_step + self.potential = cb_params.energy.asnumpy().squeeze() + if self.use_updater: + self.tot_energy = self.potential + self.kinetics + info = 'Step: '+str(step) + ', ' + info += 'E_pot: ' + str(self.potential) + ', ' + if self.use_updater: + self.tot_energy = self.potential + self.kinetics + info += 'E_kin: ' + str(self.kinetics) + ', ' + info += 'E_tot: ' + str(self.tot_energy) + ', ' + info += 'Temperature: ' + str(self.temperature) + if self.use_pbc: + self.pressure = cb_params.pressure.asnumpy().squeeze() + info += ', Pressure: ' + str(self.pressure) + self.volume = cb_params.volume.asnumpy().squeeze() + info += ', Volume: ' + str(self.volume) + print(info) + + self.count += 1 + + def end(self, run_context: RunContext): + """ + Called once after network training. + + Args: + run_context (RunContext): Include some information of the model. + """ diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba0310b2ddbdc29dded8101451f232dd9a9163f1 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""init"""
+from .basic import Attention, GlobalAttention
+from .msa import MSARowAttentionWithPairBias, MSAColumnAttention, MSAColumnGlobalAttention
+from .triangle import TriangleAttention, TriangleMultiplication, OuterProductMean
+from .equivariant import InvariantPointAttention
+from .transition import Transition
+# from .dense import ProcessSBR, AddInterface
+from .interface import AddInterface
+from .sbr import ProcessSBR
+
+__all__ = ['Attention', 'GlobalAttention', 'MSARowAttentionWithPairBias',
+           'MSAColumnAttention', 'MSAColumnGlobalAttention',
+           'TriangleAttention', 'TriangleMultiplication', 'OuterProductMean',
+           'InvariantPointAttention', 'Transition', 'AddInterface', 'ProcessSBR']
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/amp.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..334c7f11b22600e3d0e9d8bf10faf7090f53d707
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/amp.py
@@ -0,0 +1,49 @@
+# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""amp"""
+
+import mindspore.common.dtype as mstype
+from mindspore import nn
+from mindspore.ops import functional as F
+
+
+class OutputTo16(nn.Cell):
+    "Wrap cell for amp. Cast the wrapped cell's output to float16."
+
+    def __init__(self, op):
+        super(OutputTo16, self).__init__(auto_prefix=False)
+        self._op = op
+
+    def construct(self, *x):
+        return F.cast(self._op(*x), mstype.float16)
+
+
+def amp_convert(network, white_list=None):
+    """Cast `network` to float16, keeping cells whose type is in `white_list`
+    in float32 and casting their outputs back to float16."""
+    network.to_float(mstype.float16)
+    if white_list is not None:
+        cells = network.name_cells()
+        change = False
+        for name in cells:
+            subcell = cells[name]
+            if subcell == network:
+                continue
+            elif isinstance(subcell, white_list):
+                network._cells[name] = OutputTo16(subcell.to_float(mstype.float32))
+                change = True
+            else:
+                amp_convert(subcell, white_list)
+        if isinstance(network, nn.SequentialCell) and change:
+            network.cell_list = list(network.cells())
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/basic.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..06780fefb38380f7762643706d0c3e5d9d07ca33
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/basic.py
@@ -0,0 +1,927 @@
+# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
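`amp_convert` above walks the cell tree, casts every cell to float16, and pins any cell whose type appears in `white_list` to float32, wrapping it in `OutputTo16` so its output rejoins the float16 graph. A minimal usage sketch (only the import path is an assumption):

```python
import mindspore.nn as nn
from mindsponge1.cell.amp import amp_convert  # assumed import path

net = nn.SequentialCell([
    nn.Dense(64, 64),
    nn.LayerNorm((64,)),  # numerically sensitive: keep in float32
    nn.Dense(64, 8),
])

# The bulk of the network runs in float16; LayerNorm cells stay in float32
# and their outputs are cast back to float16 by the OutputTo16 wrapper.
amp_convert(net, white_list=(nn.LayerNorm,))
```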
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""basic"""
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.ops import operations as P
+from .initializer import glorot_uniform
+from .dense import ProcessSBR
+
+
+class Attention(nn.Cell):
+    r"""
+    This is an implementation of multihead attention in the paper `Attention is all you need
+    <https://arxiv.org/abs/1706.03762>`_. Given the query vector with source length,
+    and the key with key length and the target length, the attention will be performed as
+    follows.
+
+    .. math::
+
+        Attention(query, key, vector) = Concat(head_1, \dots, head_h)W^O
+
+    where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The default is with a bias.
+
+    If the query, key and value tensors are the same, the layer performs a variant of
+    self-attention.
+
+    Args:
+        num_head(int): The number of the heads.
+        hidden_size(int): The hidden size of the input.
+        gating(bool): Indicator of if the attention is gated.
+        q_data_dim(int): The last dimension length of the query tensor.
+        m_data_dim(int): The last dimension length of the key and value tensor.
+        output_dim(int): The last dimension length of the output tensor.
+        device_num(int): The number of devices across which the parallel operators
+            are sharded.
+        batch_size(int): The batch size of parameters in attention, used in while
+            control flow. Default: None.
+
+    Inputs:
+        - **q_data** (Tensor) - The query tensor with shape (batch_size,
+          query_seq_length, q_data_dim) with query_seq_length the query sequence length.
+        - **m_data** (Tensor) - The key/value tensor with shape (batch_size,
+          value_seq_length, m_data_dim) with value_seq_length the value sequence length.
+        - **attention_mask** (Tensor) - The mask for attention matrix with shape
+          (batch_size, num_head, query_seq_length, value_seq_length).
+        - **index** (Tensor) - The index of while loop, only used in case of while
+          control flow. Default: None.
+        - **nonbatched_bias** (Tensor) - Non-batched bias for the attention matrix with
+          shape (num_heads, query_seq_length, value_seq_length). Default: None.
+
+    Outputs:
+        Tensor, output tensor of the Attention layer with shape (batch_size, query_seq_length, hidden_size).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import Attention
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = Attention(num_head=4, hidden_size=64, gating=True, q_data_dim=64,
+        ...
m_data_dim=64, output_dim=64) + >>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32) + >>> m_data = Tensor(np.ones((32, 256, 64)), mstype.float32) + >>> attention_mask = Tensor(np.ones((32, 4, 128, 256)), mstype.float32) + >>> attn_out= model(q_data, m_data, attention_mask) + >>> print(attn_out.shape) + (32, 128, 64) + """ + + def __init__(self, num_head, hidden_size, gating, q_data_dim, m_data_dim, output_dim, + device_num, batch_size=None): + super(Attention, self).__init__() + self.q_data_dim = q_data_dim + self.m_data_dim = m_data_dim + self.output_dim = output_dim + self.num_head = num_head + self.gating = gating + self.hidden_size = hidden_size + self.dim_per_head = self.hidden_size // self.num_head + self.batch_size = batch_size + + self.batch_matmul = P.BatchMatMul(transpose_b=True).shard(((1, device_num, 1), (1, 1))) + self.mul2 = P.Mul().shard(((1, device_num, 1), ())) + self.trans = P.Transpose().shard(((1, device_num, 1, 1),)) + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True).shard(((1,1,device_num,1), (1,1,1,1))) + self.add1 = P.Add().shard(((1,1,device_num,1), (1,device_num,1))) + self.add = P.Add().shard(((1,1,device_num,1), (1,1,1,1))) + self.softmax = P.Softmax(-1).shard(((1,1,device_num,1),)) + self.trans1 = P.Transpose().shard(((1, 1, device_num, 1),)) + self.add2 = P.Add().shard(((1,device_num,1,1), (1,1))) + self.add3 = P.Add().shard(((1,device_num,1), (1,1))) + self.sigmoid = P.Sigmoid().shard(((1, device_num, 1, 1),)) + self.matmul = P.MatMul(transpose_b=True) + self.mul = P.Mul().shard(((1, device_num, 1, 1), (1, device_num, 1, 1))) + # self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + # self.softmax = nn.Softmax() + # self.sigmoid = nn.Sigmoid() + self.batch_size = batch_size + self.gather3 = P.Gather().shard(((1, 1, 1), ())) + self.gather2 = P.Gather().shard(((1, 1), ())) + self._init_parameter() + + def construct(self, q_data, m_data, attention_mask, index=None, nonbatched_bias=None): + '''construct''' + if self.batch_size: + + linear_q_weight = self.gather3(self.linear_q_weights, index, 0) + linear_k_weight = self.gather3(self.linear_k_weights, index, 0) + linear_v_weight = self.gather3(self.linear_v_weights, index, 0) + linear_output_weight = self.gather3(self.linear_output_weights, index, 0) + o_bias = self.gather2(self.o_biases, index, 0) + linear_gating_weight = 0 + gating_bias = 0 + if self.gating: + linear_gating_weight =self.gather3(self.linear_gating_weights, index, 0) + gating_bias = self.gather3(self.gating_biases, index, 0) + else: + linear_q_weight = self.linear_q_weights + linear_k_weight = self.linear_k_weights + linear_v_weight = self.linear_v_weights + linear_output_weight = self.linear_output_weights + o_bias = self.o_biases + linear_gating_weight = 0 + gating_bias = 0 + if self.gating: + linear_gating_weight = self.linear_gating_weights + gating_bias = self.gating_biases + + dim_b, dim_q, dim_a = q_data.shape + _, dim_k, dim_c = m_data.shape + dim_h = self.num_head + + # q_data = P.Reshape()(q_data, (-1, dim_a)) # (128, 248, 256) + # m_data = P.Reshape()(m_data, (-1, dim_c)) + + # q = self.matmul(q_data, linear_q_weight) * self.dim_per_head ** (-0.5) + q = self.mul2(self.batch_matmul(q_data, linear_q_weight), self.dim_per_head ** (-0.5)) # (62, 62, 64) + # k = self.matmul(m_data, linear_k_weight) + k = self.batch_matmul(m_data, linear_k_weight) + # v = self.matmul(m_data, linear_v_weight) + v = self.batch_matmul(m_data, linear_v_weight) + + q = P.Reshape()(q, (dim_b, dim_q, dim_h, -1)) # (62, 62, 4, 16) + k = 
P.Reshape()(k, (dim_b, dim_k, dim_h, -1)) + v = P.Reshape()(v, (dim_b, dim_k, dim_h, -1)) + + # tmp_q = P.Transpose()(q, (0, 2, 1, 3)) + # tmp_k = P.Transpose()(k, (0, 2, 1, 3)) + + tmp_q = self.trans(q, (0, 2, 1, 3)) # (62, 4, 62, 16) + tmp_k = self.trans(k, (0, 2, 1, 3)) + logits = self.batch_matmul_trans_b(tmp_q, tmp_k) # (62, 4, 62, 248) + + if nonbatched_bias is not None: + # bias = P.ExpandDims()(nonbatched_bias, 0) + # logits = P.Add()(logits, bias) + logits = self.add1(logits, nonbatched_bias) + + # logits = P.Add()(logits, attention_mask) + logits = self.add(logits, attention_mask) + weights = self.softmax(logits) + # tmp_v = P.Transpose()(v, (0, 2, 3, 1)) + tmp_v = self.trans(v, (0, 2, 3, 1)) + + # weighted_avg = P.Transpose()(self.batch_matmul_trans_b(weights, tmp_v), (0, 2, 1, 3)) + weighted_avg = self.trans1(self.batch_matmul_trans_b(weights, tmp_v), (0, 2, 1, 3)) + + + if self.gating: + # gating_bias = P.ExpandDims()(P.ExpandDims()(gating_bias, 0), 0) + # gate_values = P.Add()(P.Reshape()(self.batch_matmul(q_data, linear_gating_weight), + # (dim_b, dim_q, dim_h, -1)), + # gating_bias) + gate_values = self.add2(P.Reshape()(self.batch_matmul(q_data, linear_gating_weight), + (dim_b, dim_q, dim_h, -1)), + gating_bias) + gate_values = self.sigmoid(gate_values) + weighted_avg = self.mul(weighted_avg, gate_values) + # weighted_avg = P.Reshape()(weighted_avg * gate_values, (dim_b * dim_q, -1)) + + weighted_avg = P.Reshape()(weighted_avg, (dim_b * dim_q, -1)) + # output = P.Add()(P.Reshape()(self.matmul(weighted_avg, linear_output_weight), + # (dim_b, dim_q, -1)), + # P.ExpandDims()(o_bias, 0)) + output = self.add3(P.Reshape()(self.matmul(weighted_avg, linear_output_weight), + (dim_b, dim_q, -1)), + P.ExpandDims()(o_bias, 0)) + return output + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.linear_q_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.num_head * self.dim_per_head, + self.q_data_dim]), mstype.float32)) + self.linear_k_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.num_head * self.dim_per_head, + self.m_data_dim]), mstype.float32)) + self.linear_v_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.num_head * self.dim_per_head, + self.m_data_dim]), mstype.float32)) + self.linear_output_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.output_dim, + self.num_head * \ + self.dim_per_head]), + mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim]), + mstype.float32)) + if self.gating: + self.linear_gating_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.num_head * \ + self.dim_per_head, + self.q_data_dim]), + mstype.float32)) + self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, + self.num_head, + self.dim_per_head)), + mstype.float32), name="gating_b") + else: + self.linear_q_weights = Parameter(Tensor( + glorot_uniform(self.num_head * self.q_data_dim, self.dim_per_head * self.q_data_dim, + [self.num_head * self.dim_per_head, self.q_data_dim]), + mstype.float32)) + self.linear_k_weights = Parameter(Tensor( + glorot_uniform(self.num_head * self.m_data_dim, self.dim_per_head * self.m_data_dim, + [self.num_head * self.dim_per_head, self.m_data_dim]), + mstype.float32)) + self.linear_v_weights = Parameter(Tensor( + glorot_uniform(self.num_head * self.m_data_dim, self.dim_per_head * self.m_data_dim, + [self.num_head * self.dim_per_head, self.m_data_dim]), + mstype.float32)) + self.linear_output_weights = Parameter( + 
Tensor(np.zeros([self.output_dim, self.num_head * self.dim_per_head]), + mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros([self.output_dim]), mstype.float32)) + if self.gating: + self.linear_gating_weights = Parameter( + Tensor(np.zeros([self.num_head * self.dim_per_head, self.q_data_dim]), + mstype.float32)) + self.gating_biases = Parameter(Tensor(np.ones((self.num_head, self.dim_per_head)), + mstype.float32), + name="gating_b") + + +class Attention2(nn.Cell): + r""" + This is an implementation of multihead attention in the paper `Attention is all you need + `_. Given the query vector with source length, + and the key with key length and the target length, the attention will be performed as + the following. + + .. math:: + + Attention(query, key, vector) = Concat(head_1, \dots, head_h)W^O + + where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The default is with a bias. + + if query, key and value tensor is same, then it will be modified version of self + attention. + + Args: + num_head(int): The number of the heads. + hidden_size(int): The hidden size of the input. + gating(bool): Indicator of if the attention is gated. + q_data_dim(int): The last dimension length of the query tensor. + m_data_dim(int): The last dimension length of the key and value tensor. + output_dim(int): The last dimension length of the output tensor. + batch_size(int): The batch size of parameters in attention, used in while + control flow. Default: None. + + Inputs: + - **q_data** (Tensor) - The query tensor with shape (batch_size, + query_seq_length, q_data_dim) with query_seq_length the query sequence length. + - **m_data** (Tensor) - The key/value tensor with shape (batch_size, + value_seq_length, m_data_dim) with value_seq_length the value sequence length. + - **attention_mask** (Tensor) - The mask for attention matrix with shape + (batch_size, num_head, query_seq_length, value_seq_length). + - **index** (Tensor) - The index of while loop, only used in case of while + control flow. Default: None. + - **nonbatched_bias** (Tensor) - Non-batched bias for the attention matrix with + shape(num_heads, query_seq_length, value_seq_length). Default: None. + + Outputs: + Tensor, output tensor of the Attention layer with shape (batch_size, query_seq_length, hidden_size). + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import Attention + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> model = Attention(num_head=4, hidden_size=64, gating=True, q_data_dim=64, + ... 
m_data_dim=64, output_dim=64) + >>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32) + >>> m_data = Tensor(np.ones((32, 256, 64)), mstype.float32) + >>> attention_mask = Tensor(np.ones((32, 4, 128, 256)), mstype.float32) + >>> attn_out= model(q_data, m_data, attention_mask) + >>> print(attn_out.shape) + (32, 128, 64) + """ + + def __init__(self, num_head, hidden_size, gating, q_data_dim, m_data_dim, output_dim, + device_num, batch_size=None): + super(Attention2, self).__init__() + self.q_data_dim = q_data_dim + self.m_data_dim = m_data_dim + self.output_dim = output_dim + self.num_head = num_head + self.gating = gating + self.hidden_size = hidden_size + self.dim_per_head = self.hidden_size // self.num_head + self.batch_size = batch_size + + self.batch_matmul = P.BatchMatMul(transpose_b=True).shard(((device_num, 1, 1), (1, 1))) + self.mul2 = P.Mul().shard(((device_num, 1, 1), ())) + self.trans = P.Transpose().shard(((device_num, 1, 1, 1),)) + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True).shard(((device_num,1,1,1), (device_num,1,1,1))) + self.add1 = P.Add().shard(((device_num,1,1,1), (1,1,1))) + self.add = P.Add().shard(((device_num,1,1,1), (device_num,1,1,1))) + self.softmax = P.Softmax(-1).shard(((device_num,1,1,1),)) + self.trans1 = P.Transpose().shard(((device_num, 1, 1, 1),)) + self.add2 = P.Add().shard(((device_num,1,1,1), (1,1))) + self.add3 = P.Add().shard(((device_num,1,1), (1,1))) + self.sigmoid = P.Sigmoid().shard(((device_num, 1, 1, 1),)) + self.matmul = P.MatMul(transpose_b=True) + self.mul = P.Mul().shard(((device_num, 1, 1, 1), (device_num, 1, 1, 1))) + # self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + # self.softmax = nn.Softmax() + # self.sigmoid = nn.Sigmoid() + self.batch_size = batch_size + self.gather3 = P.Gather().shard(((1, 1, 1), ())) + self.gather2 = P.Gather().shard(((1, 1), ())) + self._init_parameter() + + def construct(self, q_data, m_data, attention_mask, index=None, nonbatched_bias=None): + '''construct''' + if self.batch_size: + + linear_q_weight = self.gather3(self.linear_q_weights, index, 0) + linear_k_weight = self.gather3(self.linear_k_weights, index, 0) + linear_v_weight = self.gather3(self.linear_v_weights, index, 0) + linear_output_weight = self.gather3(self.linear_output_weights, index, 0) + o_bias = self.gather2(self.o_biases, index, 0) + linear_gating_weight = 0 + gating_bias = 0 + if self.gating: + linear_gating_weight =self.gather3(self.linear_gating_weights, index, 0) + gating_bias = self.gather3(self.gating_biases, index, 0) + else: + linear_q_weight = self.linear_q_weights + linear_k_weight = self.linear_k_weights + linear_v_weight = self.linear_v_weights + linear_output_weight = self.linear_output_weights + o_bias = self.o_biases + linear_gating_weight = 0 + gating_bias = 0 + if self.gating: + linear_gating_weight = self.linear_gating_weights + gating_bias = self.gating_biases + + dim_b, dim_q, dim_a = q_data.shape + _, dim_k, dim_c = m_data.shape + dim_h = self.num_head + + # q_data = P.Reshape()(q_data, (-1, dim_a)) # (128, 248, 256) + # m_data = P.Reshape()(m_data, (-1, dim_c)) + + # q = self.matmul(q_data, linear_q_weight) * self.dim_per_head ** (-0.5) + q = self.mul2(self.batch_matmul(q_data, linear_q_weight), self.dim_per_head ** (-0.5)) # (62, 62, 64) + # k = self.matmul(m_data, linear_k_weight) + k = self.batch_matmul(m_data, linear_k_weight) + # v = self.matmul(m_data, linear_v_weight) + v = self.batch_matmul(m_data, linear_v_weight) + + q = P.Reshape()(q, (dim_b, dim_q, dim_h, -1)) # (62, 62, 4, 16) 
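+        # Split the fused projections into heads: q/k/v leave the sharded
+        # BatchMatMul as (batch, seq, num_head * dim_per_head); reshaping to
+        # (batch, seq, num_head, dim_per_head) lets the Transpose below bring
+        # the head axis forward for the per-head attention products, which are
+        # sharded across `device_num` devices.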
+ k = P.Reshape()(k, (dim_b, dim_k, dim_h, -1)) + v = P.Reshape()(v, (dim_b, dim_k, dim_h, -1)) + + # tmp_q = P.Transpose()(q, (0, 2, 1, 3)) + # tmp_k = P.Transpose()(k, (0, 2, 1, 3)) + + tmp_q = self.trans(q, (0, 2, 1, 3)) # (62, 4, 62, 16) + tmp_k = self.trans(k, (0, 2, 1, 3)) + logits = self.batch_matmul_trans_b(tmp_q, tmp_k) # (62, 4, 62, 248) + + if nonbatched_bias is not None: + # bias = P.ExpandDims()(nonbatched_bias, 0) + # logits = P.Add()(logits, bias) + logits = self.add1(logits, nonbatched_bias) + + # logits = P.Add()(logits, attention_mask) + logits = self.add(logits, attention_mask) + weights = self.softmax(logits) + # tmp_v = P.Transpose()(v, (0, 2, 3, 1)) + tmp_v = self.trans(v, (0, 2, 3, 1)) + + # weighted_avg = P.Transpose()(self.batch_matmul_trans_b(weights, tmp_v), (0, 2, 1, 3)) + weighted_avg = self.trans1(self.batch_matmul_trans_b(weights, tmp_v), (0, 2, 1, 3)) + + + if self.gating: + # gating_bias = P.ExpandDims()(P.ExpandDims()(gating_bias, 0), 0) + # gate_values = P.Add()(P.Reshape()(self.batch_matmul(q_data, linear_gating_weight), + # (dim_b, dim_q, dim_h, -1)), + # gating_bias) + gate_values = self.add2(P.Reshape()(self.batch_matmul(q_data, linear_gating_weight), + (dim_b, dim_q, dim_h, -1)), + gating_bias) + gate_values = self.sigmoid(gate_values) + weighted_avg = self.mul(weighted_avg, gate_values) + # weighted_avg = P.Reshape()(weighted_avg * gate_values, (dim_b * dim_q, -1)) + + weighted_avg = P.Reshape()(weighted_avg, (dim_b * dim_q, -1)) + # output = P.Add()(P.Reshape()(self.matmul(weighted_avg, linear_output_weight), + # (dim_b, dim_q, -1)), + # P.ExpandDims()(o_bias, 0)) + output = self.add3(P.Reshape()(self.matmul(weighted_avg, linear_output_weight), + (dim_b, dim_q, -1)), + P.ExpandDims()(o_bias, 0)) + return output + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.linear_q_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.num_head * self.dim_per_head, + self.q_data_dim]), mstype.float32)) + self.linear_k_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.num_head * self.dim_per_head, + self.m_data_dim]), mstype.float32)) + self.linear_v_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.num_head * self.dim_per_head, + self.m_data_dim]), mstype.float32)) + self.linear_output_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.output_dim, + self.num_head * \ + self.dim_per_head]), + mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim]), + mstype.float32)) + if self.gating: + self.linear_gating_weights = Parameter(Tensor(np.zeros([self.batch_size, + self.num_head * \ + self.dim_per_head, + self.q_data_dim]), + mstype.float32)) + self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, + self.num_head, + self.dim_per_head)), + mstype.float32), name="gating_b") + else: + self.linear_q_weights = Parameter(Tensor( + glorot_uniform(self.num_head * self.q_data_dim, self.dim_per_head * self.q_data_dim, + [self.num_head * self.dim_per_head, self.q_data_dim]), + mstype.float32)) + self.linear_k_weights = Parameter(Tensor( + glorot_uniform(self.num_head * self.m_data_dim, self.dim_per_head * self.m_data_dim, + [self.num_head * self.dim_per_head, self.m_data_dim]), + mstype.float32)) + self.linear_v_weights = Parameter(Tensor( + glorot_uniform(self.num_head * self.m_data_dim, self.dim_per_head * self.m_data_dim, + [self.num_head * self.dim_per_head, self.m_data_dim]), + mstype.float32)) + self.linear_output_weights = Parameter( + 
Tensor(np.zeros([self.output_dim, self.num_head * self.dim_per_head]), + mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros([self.output_dim]), mstype.float32)) + if self.gating: + self.linear_gating_weights = Parameter( + Tensor(np.zeros([self.num_head * self.dim_per_head, self.q_data_dim]), + mstype.float32)) + self.gating_biases = Parameter(Tensor(np.ones((self.num_head, self.dim_per_head)), + mstype.float32), + name="gating_b") + +# class GlobalAttention(nn.Cell): +# r""" +# This is an implementation of global gated self attention in the paper `Highly accurate +# protein structure prediction with AlphaFold +# `_. For this attention, the +# shape of the query tensor, key tensor and the value tensor should be the same. + +# Args: +# num_head(int): The number of the heads. +# gating(bool): Indicator of if the attention is gated. +# input_dim(int): The last dimension length of the input tensor. +# output_dim(int): The last dimension length of the output tensor. +# batch_size(int): The batch size of parameters in attention, used in while control +# flow. Default: None. + +# Inputs: +# - **q_data** (Tensor) - The query tensor with shape (batch_size, seq_length, +# input_dim) with seq_length the sequence length. +# - **m_data** (Tensor) - The key/value tensor with shape (batch_size, seq_length, +# input_dim). +# - **q_mask** (Tensor) - A binary mask for q_data of shape (batch_size, +# seq_length, 1). +# - **bias** (Tensor) - Bias for the attention matrix. Default: None. +# - **index** (Tensor) - The index of while loop, only used in case of while control +# flow. Default: None. + +# Outputs: +# Tensor, Output tensor of the GlobalAttention layer with shape (batch_size, seq_length, output_dim). + +# Supported Platforms: +# ``Ascend`` ``GPU`` + +# Examples: +# >>> import numpy as np +# >>> from mindsponge.cell import GlobalAttention +# >>> from mindspore import dtype as mstype +# >>> from mindspore import Tensor +# >>> model = GlobalAttention(num_head=4, input_dim=64, gating=True, output_dim=256) +# >>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32) +# >>> m_data = Tensor(np.ones((32, 128, 64)), mstype.float32) +# >>> q_mask = Tensor(np.ones((32, 128, 1)), mstype.float32) +# >>> attn_out= model(q_data, m_data, q_mask) +# >>> print(attn_out.shape) +# (32, 128, 256) +# """ + +# def __init__(self, num_head, gating, input_dim, output_dim, batch_size=None): +# super(GlobalAttention, self).__init__() + +# self.input_dim = input_dim +# self.num_head = num_head +# self.dim_per_head = self.input_dim // self.num_head +# self.output_dim = output_dim +# self.matmul_trans_b = P.MatMul(transpose_b=True) +# self.batch_matmul = P.BatchMatMul() +# self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) +# self.matmul = P.MatMul() +# self.softmax = nn.Softmax() +# self.sigmoid = nn.Sigmoid() +# self.gating = gating +# self.batch_size = batch_size +# self._init_parameter() +# self.gather3 = P.Gather().shard(((1, 1, 1), ())) +# self.gather2 = P.Gather().shard(((1, 1), ())) + +# def construct(self, q_data, m_data, q_mask, index=None): +# '''construct''' +# if self.batch_size: +# q_weights = self.gather3(self.linear_q_weights, index, 0) +# k_weights = self.gather3(self.linear_k_weights, index, 0) +# v_weights = self.gather3(self.linear_v_weights, index, 0) +# output_weights = self.gather3(self.linear_output_weights, index, 0) +# output_bias = self.gather2(self.o_biases, index, 0) +# gating_weights = 0 +# gating_bias = 0 +# if self.gating: +# gating_weights = 
self.gather3(self.linear_gating_weights, index, 0) +# gating_bias = self.gather3(self.gating_biases, index, 0) +# else: +# q_weights = self.linear_q_weights +# k_weights = self.linear_k_weights +# v_weights = self.linear_v_weights +# output_weights = self.linear_output_weights +# output_bias = self.o_biases +# gating_weights = 0 +# gating_bias = 0 +# if self.gating: +# gating_weights = self.linear_gating_weights +# gating_bias = self.gating_biases + +# b, _, _ = m_data.shape # (62, 2048, 64) + +# v_weights = P.BroadcastTo((b, +# self.dim_per_head * self.num_head, +# self.dim_per_head))(v_weights) # (1, 64, 8) -> (62, 64, 8) +# v = self.batch_matmul(m_data, v_weights) + +# mask_shape = q_mask.shape # (62, 2048, 1) +# value_shape = q_data.shape # (62, 2048, 64) +# broadcast_factor = 1. +# value_size = value_shape[1] +# mask_size = mask_shape[1] +# if mask_size == 1: +# broadcast_factor = broadcast_factor * value_size +# qa = P.ReduceSum()(q_mask * q_data, 1) +# qb = P.ReduceSum()(q_mask, 1) * broadcast_factor + 1e-10 +# q_avg = P.RealDiv()(qa, qb) + +# q = P.Reshape()(self.matmul(q_avg, q_weights), +# (-1, self.num_head, self.dim_per_head)) * (self.dim_per_head ** (-0.5)) + +# k_weights = P.BroadcastTo((b, +# self.dim_per_head * self.num_head, +# self.dim_per_head))(k_weights) +# k = self.batch_matmul(m_data, k_weights) + +# attention_mask = 1e9 * (P.Transpose()(q_mask, (0, 2, 1)) - 1.0) +# logits = P.Add()(self.batch_matmul_trans_b(q, k), attention_mask) + +# weights = self.softmax(logits) +# weighted_avg = self.batch_matmul(weights, v) + +# if self.gating: +# q_data_shape = P.Shape()(q_data) +# if len(q_data_shape) != 2: +# q_data = P.Reshape()(q_data, (-1, q_data_shape[-1])) +# out_shape = q_data_shape[:-1] + (-1,) +# gate_values = P.Reshape()(self.matmul_trans_b(q_data, gating_weights) + gating_bias, +# out_shape) + +# gate_values = P.Reshape()(self.sigmoid(gate_values), +# (b, -1, self.num_head, self.dim_per_head)) +# weighted_avg = P.Reshape()(P.ExpandDims()(weighted_avg, 1) * gate_values, +# (-1, self.num_head * self.dim_per_head)) +# weighted_avg_shape = P.Shape()(weighted_avg) +# if len(weighted_avg_shape) != 2: +# weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1])) +# output = P.Reshape()(P.Add()(self.matmul_trans_b(weighted_avg, +# output_weights), output_bias), +# (b, -1, self.output_dim)) +# else: +# weighted_avg = P.Reshape()(weighted_avg, (-1, self.num_head * self.dim_per_head)) +# weighted_avg_shape = P.Shape()(weighted_avg) +# if len(weighted_avg_shape) != 2: +# weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1])) +# out_shape = weighted_avg_shape[:-1] + (-1,) +# output = P.Reshape()(P.Add()(self.matmul_trans_b(weighted_avg, output_weights), +# output_bias), out_shape) +# output = P.ExpandDims()(output, -1) +# return output + +# def _init_parameter(self): +# '''init parameter''' +# if self.batch_size: +# self.linear_q_weights = Parameter( +# Tensor(np.zeros((self.batch_size, +# self.input_dim, +# self.num_head, +# self.dim_per_head)), +# mstype.float32)) +# self.linear_k_weights = Parameter( +# Tensor(np.zeros((self.batch_size, self.input_dim, self.dim_per_head)), +# mstype.float32)) +# self.linear_v_weights = Parameter( +# Tensor(np.zeros((self.batch_size, self.input_dim, self.dim_per_head)), +# mstype.float32)) +# self.linear_output_weights = Parameter( +# Tensor(np.zeros((self.batch_size, +# self.output_dim, +# self.num_head * self.dim_per_head)), +# mstype.float32)) +# self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, 
self.output_dim)), +# mstype.float32)) +# if self.gating: +# self.linear_gating_weights = Parameter( +# Tensor(np.zeros((self.batch_size, +# self.num_head * self.dim_per_head, +# self.input_dim)), +# mstype.float32)) +# self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.input_dim)), +# mstype.float32)) +# else: +# self.linear_q_weights = Parameter(Tensor( +# glorot_uniform(self.num_head * self.input_dim, +# self.dim_per_head * self.input_dim, +# (self.input_dim, self.num_head*self.dim_per_head)), +# mstype.float32)) +# self.linear_k_weights = Parameter( +# Tensor(glorot_uniform(self.input_dim, +# self.dim_per_head, +# (1, self.input_dim, self.dim_per_head)), +# mstype.float32)) +# self.linear_v_weights = Parameter( +# Tensor(glorot_uniform(self.input_dim, +# self.dim_per_head, +# (1, self.input_dim, self.dim_per_head)), +# mstype.float32)) +# self.linear_output_weights = Parameter( +# Tensor(np.zeros((self.output_dim, self.num_head * self.dim_per_head)), +# mstype.float32)) +# self.o_biases = Parameter(Tensor(np.zeros((self.output_dim)), +# mstype.float32)) +# if self.gating: +# self.linear_gating_weights = Parameter( +# Tensor(np.zeros((self.num_head * self.dim_per_head, self.input_dim)), +# mstype.float32)) +# self.gating_biases = Parameter(Tensor(np.ones((self.input_dim)), mstype.float32)) + + +class GlobalAttention(nn.Cell): + r""" + This is an implementation of global gated self attention in the paper `Highly accurate + protein structure prediction with AlphaFold + `_. For this attention, the + shape of the query tensor, key tensor and the value tensor should be the same. + + Args: + num_head(int): The number of the heads. + gating(bool): Indicator of if the attention is gated. + input_dim(int): The last dimension length of the input tensor. + output_dim(int): The last dimension length of the output tensor. + batch_size(int): The batch size of parameters in attention, used in while control + flow. Default: None. + + Inputs: + - **q_data** (Tensor) - The query tensor with shape (batch_size, seq_length, + input_dim) with seq_length the sequence length. + - **m_data** (Tensor) - The key/value tensor with shape (batch_size, seq_length, + input_dim). + - **q_mask** (Tensor) - A binary mask for q_data of shape (batch_size, + seq_length, 1). + - **bias** (Tensor) - Bias for the attention matrix. Default: None. + - **index** (Tensor) - The index of while loop, only used in case of while control + flow. Default: None. + + Outputs: + Tensor, Output tensor of the GlobalAttention layer with shape (batch_size, seq_length, output_dim). 
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import GlobalAttention + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> model = GlobalAttention(num_head=4, input_dim=64, gating=True, output_dim=256) + >>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32) + >>> m_data = Tensor(np.ones((32, 128, 64)), mstype.float32) + >>> q_mask = Tensor(np.ones((32, 128, 1)), mstype.float32) + >>> attn_out= model(q_data, m_data, q_mask) + >>> print(attn_out.shape) + (32, 128, 256) + """ + + def __init__(self, num_head, gating, input_dim, output_dim, device_num, batch_size=None): + super(GlobalAttention, self).__init__() + + self.input_dim = input_dim + self.num_head = num_head + self.dim_per_head = self.input_dim // self.num_head + self.output_dim = output_dim + self.matmul_trans_b = P.MatMul(transpose_b=True) + self.batch_matmul = P.BatchMatMul().shard(((1, device_num, 1), (1, 1, 1))) + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.batch_matmul_trans_b2 = P.BatchMatMul(transpose_b=True).shard(((1, 1, 1), (1, device_num, 1))) + self.matmul = P.MatMul() + self.softmax = nn.Softmax() + self.sigmoid = nn.Sigmoid() + self.gating = gating + self.batch_size = batch_size + self._init_parameter() + self.reduce_sum = P.ReduceSum().shard(((1, device_num, 1),)) + self.trans = P.Transpose().shard(((1, device_num, 1),)) + self.mul = P.Mul().shard(((1, device_num, 1), (1, device_num, 1))) + self.sub = P.Sub().shard(((1, 1, device_num), ())) + self.mul2 = P.Mul().shard(((1, 1, device_num), ())) + self.add2 = P.Add().shard(((1, 1, device_num), (1, 1, device_num))) + self.batch_matmul2 = P.BatchMatMul().shard(((1, 1, device_num), (1, device_num, 1))) + + self.gather3 = P.Gather().shard(((1, 1, 1), ())) + self.gather2 = P.Gather().shard(((1, 1), ())) + + def construct(self, q_data, m_data, q_mask, index=None): + '''construct''' + if self.batch_size: + q_weights = self.gather3(self.linear_q_weights, index, 0) + k_weights = self.gather3(self.linear_k_weights, index, 0) + v_weights = self.gather3(self.linear_v_weights, index, 0) + output_weights = self.gather3(self.linear_output_weights, index, 0) + output_bias = self.gather2(self.o_biases, index, 0) + gating_weights = 0 + gating_bias = 0 + if self.gating: + gating_weights = self.gather3(self.linear_gating_weights, index, 0) + gating_bias = self.gather3(self.gating_biases, index, 0) + else: + q_weights = self.linear_q_weights + k_weights = self.linear_k_weights + v_weights = self.linear_v_weights + output_weights = self.linear_output_weights + output_bias = self.o_biases + gating_weights = 0 + gating_bias = 0 + if self.gating: + gating_weights = self.linear_gating_weights + gating_bias = self.gating_biases + + b, _, _ = m_data.shape # (62, 2048, 64) + + v_weights = P.BroadcastTo((b, + self.dim_per_head * self.num_head, + self.dim_per_head))(v_weights) # (1, 64, 8) -> (62, 64, 8) + v = self.batch_matmul(m_data, v_weights) + + + mask_shape = q_mask.shape # (62, 2048, 1) + value_shape = q_data.shape # (62, 2048, 64) + broadcast_factor = 1. + value_size = value_shape[1] + mask_size = mask_shape[1] + if mask_size == 1: + broadcast_factor = broadcast_factor * value_size + + # qa = self.reduce_sum(self.) 
+ # qa = P.ReduceSum()(q_mask * q_data, 1) + qa = self.reduce_sum(self.mul(q_mask, q_data), 1) + # qb = self.add(self.mul2(self.reduce_sum(q_mask, 1), broadcast_factor), 1e-10) + qb = self.reduce_sum(q_mask, 1) * broadcast_factor + 1e-10 + + # qb = P.ReduceSum()(q_mask, 1) * broadcast_factor + 1e-10 + q_avg = P.RealDiv()(qa, qb) + + q = P.Reshape()(self.matmul(q_avg, q_weights), + (-1, self.num_head, self.dim_per_head)) * (self.dim_per_head ** (-0.5)) + + k_weights = P.BroadcastTo((b, + self.dim_per_head * self.num_head, + self.dim_per_head))(k_weights) + k = self.batch_matmul(m_data, k_weights) + + # attention_mask = 1e9 * (P.Transpose()(q_mask, (0, 2, 1)) - 1.0) + # logits = P.Add()(self.batch_matmul_trans_b(q, k), attention_mask) + + attention_mask = self.mul2((self.sub(self.trans(q_mask, (0, 2, 1)), 1.0)), 1e9) + logits = self.add2(self.batch_matmul_trans_b2(q, k), attention_mask) + + weights = self.softmax(logits) + weighted_avg = self.batch_matmul2(weights, v) + + if self.gating: + q_data_shape = P.Shape()(q_data) + if len(q_data_shape) != 2: + q_data = P.Reshape()(q_data, (-1, q_data_shape[-1])) + out_shape = q_data_shape[:-1] + (-1,) + gate_values = P.Reshape()(self.matmul_trans_b(q_data, gating_weights) + gating_bias, + out_shape) + + gate_values = P.Reshape()(self.sigmoid(gate_values), + (b, -1, self.num_head, self.dim_per_head)) + weighted_avg = P.Reshape()(P.ExpandDims()(weighted_avg, 1) * gate_values, + (-1, self.num_head * self.dim_per_head)) + weighted_avg_shape = P.Shape()(weighted_avg) + if len(weighted_avg_shape) != 2: + weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1])) + output = P.Reshape()(P.Add()(self.matmul_trans_b(weighted_avg, + output_weights), output_bias), + (b, -1, self.output_dim)) + else: + weighted_avg = P.Reshape()(weighted_avg, (-1, self.num_head * self.dim_per_head)) + weighted_avg_shape = P.Shape()(weighted_avg) + if len(weighted_avg_shape) != 2: + weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1])) + out_shape = weighted_avg_shape[:-1] + (-1,) + output = P.Reshape()(P.Add()(self.matmul_trans_b(weighted_avg, output_weights), + output_bias), out_shape) + output = P.ExpandDims()(output, -1) + return output + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.linear_q_weights = Parameter( + Tensor(np.zeros((self.batch_size, + self.input_dim, + self.num_head, + self.dim_per_head)), + mstype.float32)) + self.linear_k_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim, self.dim_per_head)), + mstype.float32)) + self.linear_v_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim, self.dim_per_head)), + mstype.float32)) + self.linear_output_weights = Parameter( + Tensor(np.zeros((self.batch_size, + self.output_dim, + self.num_head * self.dim_per_head)), + mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.output_dim)), + mstype.float32)) + if self.gating: + self.linear_gating_weights = Parameter( + Tensor(np.zeros((self.batch_size, + self.num_head * self.dim_per_head, + self.input_dim)), + mstype.float32)) + self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.input_dim)), + mstype.float32)) + else: + self.linear_q_weights = Parameter(Tensor( + glorot_uniform(self.num_head * self.input_dim, + self.dim_per_head * self.input_dim, + (self.input_dim, self.num_head*self.dim_per_head)), + mstype.float32)) + self.linear_k_weights = Parameter( + Tensor(glorot_uniform(self.input_dim, + self.dim_per_head, + (1, 
self.input_dim, self.dim_per_head)), + mstype.float32)) + self.linear_v_weights = Parameter( + Tensor(glorot_uniform(self.input_dim, + self.dim_per_head, + (1, self.input_dim, self.dim_per_head)), + mstype.float32)) + self.linear_output_weights = Parameter( + Tensor(np.zeros((self.output_dim, self.num_head * self.dim_per_head)), + mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros((self.output_dim)), + mstype.float32)) + if self.gating: + self.linear_gating_weights = Parameter( + Tensor(np.zeros((self.num_head * self.dim_per_head, self.input_dim)), + mstype.float32)) + self.gating_biases = Parameter(Tensor(np.ones((self.input_dim)), mstype.float32)) \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/dense.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..e22087bf31f56dcecca1504d5f06db1754b4cae6 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/dense.py @@ -0,0 +1,120 @@ +import mindspore as ms +import numpy as np +from mindspore import ops, Tensor, Parameter, nn +from mindspore.ops import operations as P + +class NetBatch(nn.Cell): + + def __init__(self, batch_size=None): + super().__init__() + self.batch_size = batch_size + + def _new_shape(self, shape): + if self.batch_size is not None: + shape = [self.batch_size,]+list(shape) + return shape + + def _get_params(self, index): + if index is not None: + ls = [] + for p in self._params.values(): + ls.append(p[index]) + return ls + else: + return self._params.values() + + +class Dense(NetBatch): + # no activation, zero bias init, lecun weight init + def __init__(self, input_dim, output_dim, batch_size=None, is_gate=False): + super().__init__(batch_size) + self.input_dim = input_dim + self.output_dim = output_dim + self.is_gate = is_gate + self.matmul = P.MatMul() + self._init_parameter() + + def construct(self, x, index=None): + w, b = self._get_params(index) + # y = ops.matmul(x, w) + b + + y = P.Reshape()(x, (-1, x.shape[-1])) + y = self.matmul(y, w) + b + y = P.Reshape()(y, x.shape[:-1]+(-1,)) + return y + + def _lecun_normal(self, dim_in, shape): + stddev = 1./np.sqrt(dim_in) + return np.random.normal(loc=0, scale=stddev, size=shape) + + def _init_parameter(self): + w_shape = self._new_shape((self.input_dim, self.output_dim)) + b_shape = self._new_shape((self.output_dim,)) + if self.is_gate: + self.weight = Parameter(Tensor(np.zeros(w_shape), ms.float32)) + self.bias = Parameter(Tensor(np.ones(b_shape), ms.float32)) + else: + # self.weight = Parameter(Tensor(self._lecun_normal(self.input_dim, w_shape), ms.float32)) + self.weight = Parameter(Tensor(np.zeros(w_shape), ms.float32)) + self.bias = Parameter(Tensor(np.zeros(b_shape), ms.float32)) + + +class ProcessSBR(nn.Cell): + + def __init__(self, sbr_act_dim, output_dim, batch_size=None, gate=False, pair_input_dim=0): + super().__init__() + self.sbr_act_dim = sbr_act_dim + self.atte_dim = output_dim + self.linear1 = Dense(sbr_act_dim, output_dim, batch_size) + if gate: + self.linear2 = Dense(sbr_act_dim+pair_input_dim, output_dim, batch_size, is_gate=True) + self.sigmoid = nn.Sigmoid() + + def construct(self, sbr_act, sbr_mask, pair=None, index=None): + y = self.linear1(sbr_act, index) + if pair is not None: + sbr_act = ops.Tile()(sbr_act, pair.shape[:-3]+(1, 1, 1)) + gate = ops.Concat(-1)((sbr_act, pair)) + gate = self.sigmoid(self.linear2(gate, index)) + y *= gate + y *= sbr_mask[..., None] + return y + +class 
AddInterface(nn.Cell): + + def __init__(self, input_dim, batch_size=None): + super().__init__() + self.linear = Dense(input_dim+1, input_dim, batch_size) + + def construct(self, interface_mask, act, index=None): + mask = interface_mask[..., None] + mask = ops.Tile()(mask, act.shape[:-2]+(1, 1)) + x = ops.Concat(-1)((act, mask)) + y = self.linear(x, index) + y *= mask + return y + + + +# ds = Dense(3, 5, 2) +# x = Tensor(np.arange(24).reshape((2,4,3)), ms.float32) +# y = ds(x, 1) +# y.shape + +# sbr_act = Tensor(np.random.normal(size=(4,4,3)), ms.float32) +# atte = Tensor(np.random.normal(size=(4,4,2)), ms.float32) +# sbr_mask = Tensor(np.random.rand(4,4)<0.5, ms.float32) +# print(sbr_act, atte, sbr_mask) +# psbr = ProcessSBR(3, 2, batch_size=5) +# y = psbr(sbr_act, atte, sbr_mask, index=3) +# print(y, y.shape) + +# single_act = Tensor(np.random.normal(size=(4, 2)), ms.float32) +# msa_act = Tensor(np.random.normal(size=(3, 4, 2)), ms.float32) +# interface_mask = Tensor(np.random.rand(4)<0.5, ms.float32) +# print(single_act, msa_act, interface_mask, sep='\n') +# aif = AddInterface(2, batch_size=5) +# y_single = aif(interface_mask, single_act, index=3) +# y_msa = aif(interface_mask, msa_act, index=3) +# print('single', y_single.shape, y_single) +# print('msa', y_msa.shape, y_msa) \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/dense1.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/dense1.py new file mode 100644 index 0000000000000000000000000000000000000000..73ab6c69b7f1bbaa3282eb94d8b15eb36215b99c --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/dense1.py @@ -0,0 +1,114 @@ +import mindspore as ms +import numpy as np +from mindspore import ops, Tensor, Parameter, nn + +class NetBatch(nn.Cell): + + def __init__(self, batch_size=None): + super().__init__() + self.batch_size = batch_size + + def _new_shape(self, shape): + if self.batch_size is not None: + shape = [self.batch_size,]+list(shape) + return shape + + def _get_params(self, index): + if index is not None: + ls = [] + for p in self._params.values(): + ls.append(p[index]) + return ls + else: + return self._params.values() + + +class Dense(NetBatch): + # no activation, zero bias init, lecun weight init + def __init__(self, input_dim, output_dim, batch_size=None, is_gate=False): + super().__init__(batch_size) + self.input_dim = input_dim + self.output_dim = output_dim + self.is_gate = is_gate + self._init_parameter() + + def construct(self, x, index=None): + w, b = self._get_params(index) + y = ops.matmul(x, w) + b + return y + + def _lecun_normal(self, dim_in, shape): + stddev = 1./np.sqrt(dim_in) + return np.random.normal(loc=0, scale=stddev, size=shape) + + def _init_parameter(self): + w_shape = self._new_shape((self.input_dim, self.output_dim)) + b_shape = self._new_shape((self.output_dim,)) + if self.is_gate: + self.weight = Parameter(Tensor(np.zeros(w_shape), ms.float32)) + self.bias = Parameter(Tensor(np.ones(b_shape), ms.float32)) + else: + # self.weight = Parameter(Tensor(self._lecun_normal(self.input_dim, w_shape), ms.float32)) + self.weight = Parameter(Tensor(np.zeros(w_shape), ms.float32)) + self.bias = Parameter(Tensor(np.zeros(b_shape), ms.float32)) + + +class ProcessSBR(nn.Cell): + + def __init__(self, sbr_act_dim, output_dim, batch_size=None, gate=False, pair_input_dim=0): + super().__init__() + self.sbr_act_dim = sbr_act_dim + self.atte_dim = output_dim + self.linear1 = Dense(sbr_act_dim, output_dim, batch_size) + if gate: + 
self.linear2 = Dense(sbr_act_dim+pair_input_dim, output_dim, batch_size, is_gate=True) + self.sigmoid = nn.Sigmoid() + + def construct(self, sbr_act, sbr_mask, pair=None, index=None): + y = self.linear1(sbr_act, index) + if pair is not None: + sbr_act = ops.Tile()(sbr_act, pair.shape[:-3]+(1, 1, 1)) + gate = ops.Concat(-1)((sbr_act, pair)) + gate = self.sigmoid(self.linear2(gate, index)) + y *= gate + y *= sbr_mask[..., None] + return y + +class AddInterface(nn.Cell): + + def __init__(self, input_dim, batch_size=None): + super().__init__() + self.linear = Dense(input_dim+1, input_dim, batch_size) + + def construct(self, interface_mask, act, index=None): + mask = interface_mask[..., None] + mask = ops.Tile()(mask, act.shape[:-2]+(1, 1)) + x = ops.Concat(-1)((act, mask)) + y = self.linear(x, index) + y *= mask + return y + + + +# ds = Dense(3, 5, 2) +# x = Tensor(np.arange(24).reshape((2,4,3)), ms.float32) +# y = ds(x, 1) +# y.shape + +# sbr_act = Tensor(np.random.normal(size=(4,4,3)), ms.float32) +# atte = Tensor(np.random.normal(size=(4,4,2)), ms.float32) +# sbr_mask = Tensor(np.random.rand(4,4)<0.5, ms.float32) +# print(sbr_act, atte, sbr_mask) +# psbr = ProcessSBR(3, 2, batch_size=5) +# y = psbr(sbr_act, atte, sbr_mask, index=3) +# print(y, y.shape) + +# single_act = Tensor(np.random.normal(size=(4, 2)), ms.float32) +# msa_act = Tensor(np.random.normal(size=(3, 4, 2)), ms.float32) +# interface_mask = Tensor(np.random.rand(4)<0.5, ms.float32) +# print(single_act, msa_act, interface_mask, sep='\n') +# aif = AddInterface(2, batch_size=5) +# y_single = aif(interface_mask, single_act, index=3) +# y_msa = aif(interface_mask, msa_act, index=3) +# print('single', y_single.shape, y_single) +# print('msa', y_msa.shape, y_msa) \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/equivariant.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/equivariant.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1137b4993e84dd86170d0e80010218eb2ac240 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/equivariant.py @@ -0,0 +1,244 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Equivariant""" +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +import mindspore.numpy as mnp +import mindspore.ops as ops +from mindspore import Parameter +from mindspore.common.tensor import Tensor +from ..common.geometry import apply_to_point, invert_point +from .initializer import lecun_init + + +class InvariantPointAttention(nn.Cell): + r""" + Invariant Point attention module. + This module is used to update the sequence representation ,which is the first input--inputs_1d, + adding location information to the sequence representation. 
+
+    The attention consists of three parts, namely, q, k, v obtained from the sequence representation,
+    q', k', v' obtained from the interaction between the sequence representation and the rigid body group,
+    and b, which is the bias, obtained from the pair representation (the second input -- inputs_2d).
+
+    .. math::
+
+        a_{ij} = Softmax(w_l(c_1{q_i}^Tk_j + b_{ij} - c_2\sum {\left \| T_i \circ q'_i - T_j \circ k'_j \right \|^{2}}))
+
+    where i and j represent the ith and jth amino acids in the sequence, respectively,
+    and T is the rotation and translation in the input.
+
+    `Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention"
+    <https://www.nature.com/articles/s41586-021-03819-2>`_.
+
+    Args:
+        num_head (int): The number of the heads.
+        num_scalar_qk (int): The number of the scalar query/key.
+        num_scalar_v (int): The number of the scalar value.
+        num_point_v (int): The number of the point value.
+        num_point_qk (int): The number of the point query/key.
+        num_channel (int): The number of the channel.
+        pair_dim (int): The last dimension length of pair.
+
+    Inputs:
+        - **inputs_1d** (Tensor) - The first row of the msa representation which is the output of the evoformer
+          module, also called the sequence representation, shape :math:`[N_{res}, num\_channel]`.
+        - **inputs_2d** (Tensor) - The pair representation which is the output of the evoformer module,
+          shape :math:`[N_{res}, N_{res}, pair\_dim]`.
+        - **mask** (Tensor) - A mask that determines which elements of inputs_1d are involved in the
+          attention calculation, shape :math:`[N_{res}, 1]`.
+        - **rotation** (tuple) - A rotation term in a rigid body group T(r,t),
+          a tuple of length 9; the shape of each element in the tuple is :math:`[N_{res}]`.
+        - **translation** (tuple) - A translation term in a rigid body group T(r,t),
+          a tuple of length 3; the shape of each element in the tuple is :math:`[N_{res}]`.
+
+    Outputs:
+        Tensor, the update of inputs_1d, shape :math:`[N_{res}, num\_channel]`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import InvariantPointAttention
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> import mindspore.context as context
+        >>> context.set_context(mode=context.GRAPH_MODE)
+        >>> model = InvariantPointAttention(num_head=12, num_scalar_qk=16, num_scalar_v=16,
+        ...                                 num_point_v=8, num_point_qk=4,
+        ...
num_channel=384, pair_dim=128) + >>> inputs_1d = Tensor(np.ones((256, 384)), mstype.float32) + >>> inputs_2d = Tensor(np.ones((256, 256, 128)), mstype.float32) + >>> mask = Tensor(np.ones((256, 1)), mstype.float32) + >>> rotation = tuple([Tensor(np.ones(256), mstype.float16) for _ in range(9)]) + >>> translation = tuple([Tensor(np.ones(256), mstype.float16) for _ in range(3)]) + >>> attn_out = model(inputs_1d, inputs_2d, mask, rotation, translation) + >>> print(attn_out.shape) + (256, 384) + """ + + def __init__(self, num_head, num_scalar_qk, num_scalar_v, num_point_v, num_point_qk, num_channel, pair_dim): + super(InvariantPointAttention, self).__init__() + + self._dist_epsilon = 1e-8 + self.num_head = num_head + self.num_scalar_qk = num_scalar_qk + self.num_scalar_v = num_scalar_v + self.num_point_v = num_point_v + self.num_point_qk = num_point_qk + self.num_channel = num_channel + self.projection_num = self.num_head * self.num_scalar_v + self.num_head * self.num_point_v * 4 + \ + self.num_head * pair_dim + self.q_scalar = nn.Dense(self.num_channel, self.num_head * self.num_scalar_qk, + weight_init=lecun_init(self.num_channel)) + self.kv_scalar = nn.Dense(self.num_channel, self.num_head * (self.num_scalar_qk + self.num_scalar_v), + weight_init=lecun_init(self.num_channel)) + self.q_point_local = nn.Dense(self.num_channel, self.num_head * 3 * self.num_point_qk, + weight_init=lecun_init(self.num_channel) + ) + self.kv_point_local = nn.Dense(self.num_channel, self.num_head * 3 * (self.num_point_qk + self.num_point_v), + weight_init=lecun_init(self.num_channel)) + self.soft_max = nn.Softmax() + self.soft_plus = ops.Softplus() + self.trainable_point_weights = Parameter(Tensor(np.ones((12,)), mstype.float32), name="trainable_point_weights") + self.attention_2d = nn.Dense(pair_dim, self.num_head, weight_init=lecun_init(pair_dim)) + self.output_projection = nn.Dense(self.projection_num, self.num_channel, weight_init='zeros' + ) + self.scalar_weights = Tensor(np.sqrt(1.0 / (3 * 16)).astype(np.float32)) + self.point_weights = Tensor(np.sqrt(1.0 / (3 * 18)).astype(np.float32)) + self.attention_2d_weights = Tensor(np.sqrt(1.0 / 3).astype(np.float32)) + + def construct(self, inputs_1d, inputs_2d, mask, rotation, translation): + '''construct''' + num_residues, _ = inputs_1d.shape + + # Improve readability by removing a large number of 'self's. + num_head = self.num_head + num_scalar_qk = self.num_scalar_qk + num_point_qk = self.num_point_qk + num_scalar_v = self.num_scalar_v + num_point_v = self.num_point_v + + # Construct scalar queries of shape: + q_scalar = self.q_scalar(inputs_1d) + q_scalar = mnp.reshape(q_scalar, [num_residues, num_head, num_scalar_qk]) + + # Construct scalar keys/values of shape: + kv_scalar = self.kv_scalar(inputs_1d) + kv_scalar = mnp.reshape(kv_scalar, [num_residues, num_head, num_scalar_v + num_scalar_qk]) + k_scalar, v_scalar = mnp.split(kv_scalar, [num_scalar_qk], axis=-1) + + # Construct query points of shape: + # First construct query points in local frame. + q_point_local = self.q_point_local(inputs_1d) + + q_point_local = mnp.split(q_point_local, 3, axis=-1) + q_point_local = (ops.Squeeze()(q_point_local[0]), ops.Squeeze()(q_point_local[1]), + ops.Squeeze()(q_point_local[2])) + # Project query points into global frame. + q_point_global = apply_to_point(rotation, translation, q_point_local, 1) + + # Reshape query point for later use. 
+ q_point0 = mnp.reshape(q_point_global[0], (num_residues, num_head, num_point_qk)) + q_point1 = mnp.reshape(q_point_global[1], (num_residues, num_head, num_point_qk)) + q_point2 = mnp.reshape(q_point_global[2], (num_residues, num_head, num_point_qk)) + + # Construct key and value points. + # Key points have shape [num_residues, num_head, num_point_qk] + # Value points have shape [num_residues, num_head, num_point_v] + + # Construct key and value points in local frame. + kv_point_local = self.kv_point_local(inputs_1d) + + kv_point_local = mnp.split(kv_point_local, 3, axis=-1) + kv_point_local = (ops.Squeeze()(kv_point_local[0]), ops.Squeeze()(kv_point_local[1]), + ops.Squeeze()(kv_point_local[2])) + # Project key and value points into global frame. + kv_point_global = apply_to_point(rotation, translation, kv_point_local, 1) + + kv_point_global0 = mnp.reshape(kv_point_global[0], (num_residues, num_head, (num_point_qk + num_point_v))) + kv_point_global1 = mnp.reshape(kv_point_global[1], (num_residues, num_head, (num_point_qk + num_point_v))) + kv_point_global2 = mnp.reshape(kv_point_global[2], (num_residues, num_head, (num_point_qk + num_point_v))) + + # Split key and value points. + k_point0, v_point0 = mnp.split(kv_point_global0, [num_point_qk], axis=-1) + k_point1, v_point1 = mnp.split(kv_point_global1, [num_point_qk], axis=-1) + k_point2, v_point2 = mnp.split(kv_point_global2, [num_point_qk], axis=-1) + + trainable_point_weights = self.soft_plus(self.trainable_point_weights) + point_weights = self.point_weights * mnp.expand_dims(trainable_point_weights, axis=1) + + v_point = [mnp.swapaxes(v_point0, -2, -3), mnp.swapaxes(v_point1, -2, -3), mnp.swapaxes(v_point2, -2, -3)] + q_point = [mnp.swapaxes(q_point0, -2, -3), mnp.swapaxes(q_point1, -2, -3), mnp.swapaxes(q_point2, -2, -3)] + k_point = [mnp.swapaxes(k_point0, -2, -3), mnp.swapaxes(k_point1, -2, -3), mnp.swapaxes(k_point2, -2, -3)] + + dist2 = mnp.square(ops.expand_dims(q_point[0], 2) - ops.expand_dims(k_point[0], 1)) + \ + mnp.square(ops.expand_dims(q_point[1], 2) - ops.expand_dims(k_point[1], 1)) + \ + mnp.square(ops.expand_dims(q_point[2], 2) - ops.expand_dims(k_point[2], 1)) + + attn_qk_point = -0.5 * mnp.sum(ops.expand_dims(ops.expand_dims(point_weights, 1), 1) * dist2, axis=-1) + + v = mnp.swapaxes(v_scalar, -2, -3) + q = mnp.swapaxes(self.scalar_weights * q_scalar, -2, -3) + k = mnp.swapaxes(k_scalar, -2, -3) + attn_qk_scalar = ops.matmul(q, mnp.swapaxes(k, -2, -1)) + attn_logits = attn_qk_scalar + attn_qk_point + + attention_2d = self.attention_2d(inputs_2d) + attention_2d = mnp.transpose(attention_2d, [2, 0, 1]) + attention_2d = self.attention_2d_weights * attention_2d + + attn_logits += attention_2d + + mask_2d = mask * mnp.swapaxes(mask, -1, -2) + attn_logits -= 50 * (1. 
- mask_2d) + + attn = self.soft_max(attn_logits) + + result_scalar = ops.matmul(attn, v) + + result_point_global = [mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[0][:, None, :, :], axis=-2), -2, -3), + mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[1][:, None, :, :], axis=-2), -2, -3), + mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[2][:, None, :, :], axis=-2), -2, -3) + ] + + result_point_global = [mnp.reshape(result_point_global[0], [num_residues, num_head * num_point_v]), + mnp.reshape(result_point_global[1], [num_residues, num_head * num_point_v]), + mnp.reshape(result_point_global[2], [num_residues, num_head * num_point_v])] + result_scalar = mnp.swapaxes(result_scalar, -2, -3) + + result_scalar = mnp.reshape(result_scalar, [num_residues, num_head * num_scalar_v]) + + result_point_local = invert_point(result_point_global, rotation, translation, 1) + + output_feature1 = result_scalar + output_feature20 = result_point_local[0] + output_feature21 = result_point_local[1] + output_feature22 = result_point_local[2] + + output_feature3 = mnp.sqrt(self._dist_epsilon + + mnp.square(result_point_local[0]) + + mnp.square(result_point_local[1]) + + mnp.square(result_point_local[2])) + + result_attention_over_2d = ops.matmul(mnp.swapaxes(attn, 0, 1), inputs_2d) + num_out = num_head * result_attention_over_2d.shape[-1] + output_feature4 = mnp.reshape(result_attention_over_2d, [num_residues, num_out]) + + final_act = mnp.concatenate([output_feature1, output_feature20, output_feature21, + output_feature22, output_feature3, output_feature4], axis=-1) + final_result = self.output_projection(final_act) + return final_result diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/initializer.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..50cf5c7db650c0a3044e0e5d6e769977ebc5427f --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/initializer.py @@ -0,0 +1,35 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""initializer""" + +import numpy as np +from mindspore.common.initializer import TruncatedNormal + +TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978, dtype=np.float32) + + +def lecun_init(fan_in, initializer_name='linear'): + """lecun init""" + scale = 1.0 + if initializer_name == 'relu': + scale *= 2 + weight_init = TruncatedNormal(sigma=np.sqrt(scale / fan_in) / TRUNCATED_NORMAL_STDDEV_FACTOR) + return weight_init + + +def glorot_uniform(fan_in, fan_out, weight_shape): + """glorot uniform""" + limit = np.sqrt(6 / (fan_in + fan_out)) + return np.random.uniform(-limit, limit, size=weight_shape) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/interface.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6d666e02813a0cd0c0145947d761d398218b7e --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/interface.py @@ -0,0 +1,83 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Interface""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore import Parameter +from mindspore.ops import operations as P +# from .mask import MaskedLayerNorm +from .sbr import lecun_normal + + +class AddInterface(nn.Cell): + '''add interface information into msa representation or single representation''' + + def __init__(self, input_dim, batch_size=None): + super(AddInterface, self).__init__() + self.matmul = P.MatMul(transpose_b=True) + self.input_dim = input_dim + self.batch_size = batch_size + # self.idx = Tensor(0, mstype.int32) + # self.masked_layer_norm = MaskedLayerNorm() + self._init_parameter() + self.gather3 = P.Gather().shard(((1, 1, 1), ())) + self.gather2 = P.Gather().shard(((1, 1), ())) + + def construct(self, interface_mask, act, index=None, mask=None): + '''Compute linear''' + if self.batch_size: + # input_layer_norm_gamma = P.Gather()(self.input_layer_norm_gammas, index, 0) + # input_layer_norm_beta = P.Gather()(self.input_layer_norm_betas, index, 0) + linear_weight = self.gather3(self.linear_weights, index, 0) + linear_bias = self.gather2(self.linear_biases, index, 0) + else: + # input_layer_norm_gamma = self.input_layer_norm_gammas + # input_layer_norm_beta = self.input_layer_norm_betas + linear_weight = self.linear_weights + linear_bias = self.linear_biases + # act = self.masked_layer_norm(act, input_layer_norm_gamma, input_layer_norm_beta, mask=mask) + + act_shape = P.Shape()(act) + interface_mask = P.ExpandDims()(interface_mask, -1) + while len(act_shape) > len(P.Shape()(interface_mask)): + interface_mask = P.ExpandDims()(interface_mask, 0) + mask1 = interface_mask + interface_mask = 
P.Tile()(interface_mask, act_shape[: -2] + (1, 1)) + # act = P.Cast()(act, mstype.float16) + act = P.Concat(-1)((act, interface_mask)) + if len(act_shape) != 2: + act = P.Reshape()(act, (-1, act_shape[-1]+1)) + act = P.BiasAdd()(self.matmul(act, linear_weight), linear_bias) + act = P.Reshape()(act, act_shape) + return act * mask1 + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + # self.input_layer_norm_gammas = Parameter( + # Tensor(np.ones((self.batch_size, self.input_dim)), mstype.float32)) + # self.input_layer_norm_betas = Parameter( + # Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32)) + self.linear_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim, self.input_dim + 1)), mstype.float32)) + self.linear_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32)) + else: + # self.input_layer_norm_gammas = Parameter(Tensor(np.ones((self.input_dim)), mstype.float32)) + # self.input_layer_norm_betas = Parameter(Tensor(np.zeros((self.input_dim)), mstype.float32)) + self.linear_weights = Parameter(Tensor(np.zeros((self.input_dim, self.input_dim + 1)), mstype.float32)) + self.linear_biases = Parameter(Tensor(np.zeros((self.input_dim,)), mstype.float32)) \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/mask.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..17716dc9c1e5e594855f5caf5222fb2427de9cd1 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/mask.py @@ -0,0 +1,95 @@ +# Copyright 2022 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Mask"""
+# from mindspore.ops import operations as P
+from mindspore import ops as P
+from mindspore.ops import functional as F
+import mindspore.nn as nn
+
+
+class LayerNormProcess(nn.Cell):
+    '''thin wrapper around P.LayerNorm that returns only the normalized output'''
+
+    def __init__(self,):
+        super(LayerNormProcess, self).__init__()
+        self.layernorm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+
+    def construct(self, msa_act, query_norm_gamma, query_norm_beta):
+        output, _, _ = self.layernorm(msa_act, query_norm_gamma, query_norm_beta)
+        return output
+
+
+class MaskedLayerNorm(nn.Cell):
+    '''masked_layer_norm'''
+
+    def __init__(self):
+        super(MaskedLayerNorm, self).__init__()
+        self.norm = LayerNormProcess()
+
+    def construct(self, act, gamma, beta, mask=None):
+        '''construct'''
+        ones = P.Ones()(act.shape[:-1] + (1,), act.dtype)
+        if mask is not None:
+            mask = F.expand_dims(mask, -1)
+            mask = mask * ones
+        else:
+            mask = ones
+        act = act * mask
+        act = self.norm(act, gamma, beta)
+        act = act * mask
+        return act
+
+
+class MaskedLayerNormParallel(nn.Cell):
+    '''masked_layer_norm with shard strategies for model parallelism'''
+
+    def __init__(self, device_num):
+        super(MaskedLayerNormParallel, self).__init__()
+        self.norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5).shard(((1, device_num, 1), (1,), (1,)))
+        self.expand = P.ExpandDims().shard(((1, device_num),))
+        self.mul = P.Mul().shard(((1, device_num, 1), (1, device_num, 1)))
+
+    def construct(self, act, gamma, beta, mask=None):
+        '''construct'''
+        ones = P.Ones()(act.shape[:-1] + (1,), act.dtype)
+        if mask is not None:
+            mask = self.expand(mask, -1)
+            mask = self.mul(mask, ones)
+        else:
+            mask = ones
+        act = self.mul(act, mask)
+        act, _, _ = self.norm(act, gamma, beta)
+        act = self.mul(act, mask)
+        return act
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/msa.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/msa.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8cf13d99ea55cd1cbaad9ced46afcf8006ab63c
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/msa.py
@@ -0,0 +1,418 @@
+# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MSA"""
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.ops import operations as P
+from .basic import Attention, GlobalAttention
+from .mask import MaskedLayerNormParallel  # MaskedLayerNorm
+from .sbr import ProcessSBR
+from .interface import AddInterface
+from ..common.utils import _memory_reduce, MemoryReduceCell
+
+
+class MSARowAttentionWithPairBias(nn.Cell):
+    r"""
+    MSA row attention. Information from the pair activations is added as a bias to the MSA row
+    attention matrix, so that the MSA state is updated using pair information.
+
+    Reference:
+        `Jumper et al. (2021) Suppl. Alg. 7 'MSARowAttentionWithPairBias'
+        `_.
+
+    Args:
+        num_head (int): The number of attention heads.
+        key_dim (int): The dimension of the attention hidden layer.
+        gating (bool): Whether the attention is gated.
+        msa_act_dim (int): The dimension of the msa_act.
+        pair_act_dim (int): The dimension of the pair_act.
+        batch_size (int): The batch size of parameters in MSA row attention, used in while control flow.
+            Default: None.
+        slice_num (int): The number of slices to be made to reduce memory. Default: 0.
+
+    Inputs:
+        - **msa_act** (Tensor) - Tensor of msa_act with shape :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .
+        - **msa_mask** (Tensor) - The mask for the MSA row attention matrix with shape :math:`(N_{seqs}, N_{res})` .
+        - **pair_act** (Tensor) - Tensor of pair_act with shape :math:`(N_{res}, N_{res}, pair\_act\_dim)` .
+          Data type is float.
+        - **index** (Tensor) - The index of the while loop, only used in case of while control flow. Default: None.
+        - **norm_msa_mask** (Tensor) - The mask of msa_act used when doing layernorm, with shape
+          :math:`(N_{seqs}, N_{res})`. Default: None.
+        - **norm_pair_mask** (Tensor) - The mask of pair_act used when doing layernorm, with shape
+          :math:`(N_{res}, N_{res})`. Default: None.
+        - **res_idx** (Tensor) - The residue index used to perform RoPE, with shape :math:`(N_{res})`. Default: None.
+
+    Outputs:
+        Tensor, the float tensor of the msa_act of the layer with shape :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import MSARowAttentionWithPairBias
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = MSARowAttentionWithPairBias(num_head=4, key_dim=4, gating=True,
+        ...                                     msa_act_dim=64, pair_act_dim=128,
+        ...                                     device_num=1, batch_size=None)
+        >>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
+        >>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
+        >>> pair_act = Tensor(np.ones((256, 256, 128)), mstype.float32)
+        >>> sbr_act = Tensor(np.ones((256, 256, 128)), mstype.float32)
+        >>> sbr_mask = Tensor(np.ones((256, 256)), mstype.float32)
+        >>> interface_mask = Tensor(np.ones((256,)), mstype.float32)
+        >>> index = None
+        >>> msa_out = model(msa_act, msa_mask, pair_act, sbr_act, sbr_mask, interface_mask, index)
+        >>> print(msa_out.shape)
+        (4, 256, 64)
+    """
+
+    def __init__(self, num_head, key_dim, gating, msa_act_dim, pair_act_dim, device_num,
+                 batch_size=None, slice_num=0, is_extra_msa=False):
+        super(MSARowAttentionWithPairBias, self).__init__()
+
+        self.num_head = num_head
+        self.batch_size = batch_size
+        self.batch_matmul = P.BatchMatMul(transpose_b=True).shard(((1, device_num, 1), (1, 1)))
+        self.attn_mod = Attention(num_head, key_dim, gating, msa_act_dim, msa_act_dim, msa_act_dim,
+                                  device_num, batch_size)
+        self.msa_act_dim = msa_act_dim
+        self.pair_act_dim = pair_act_dim
+        self.slice_num = slice_num
+        self.idx = Tensor(0, mstype.int32)
+        self.masked_layer_norm = MaskedLayerNormParallel(device_num)
+        self.is_extra_msa = is_extra_msa
+        if not is_extra_msa:
+            self.add_interface = AddInterface(msa_act_dim, batch_size)
+            self.process_sbr = ProcessSBR(128, num_head, batch_size=batch_size)
+        self._init_parameter()
+        self.gather3 = P.Gather().shard(((1, 1, 1), ()))
+        self.gather2 = P.Gather().shard(((1, 1), ()))
+        self.memory_reduce = MemoryReduceCell(self.slice_num, device_num,
+                                              strategy=[((1, device_num, 1), (1,)), ((1, 1, 1, 1), (1,))])
+
+    def construct(self, msa_act, msa_mask, pair_act, sbr_act, sbr_mask, interface_mask,
+                  index=None, norm_msa_mask=None, norm_pair_mask=None, res_idx=None):
+        '''construct'''
+        if self.batch_size:
+            query_norm_gamma = self.gather2(self.query_norm_gammas, index, 0)
+            query_norm_beta = self.gather2(self.query_norm_betas, index, 0)
+            feat_2d_norm_gamma = self.gather2(self.feat_2d_norm_gammas, index, 0)
+            feat_2d_norm_beta = self.gather2(self.feat_2d_norm_betas, index, 0)
+            feat_2d_weight = self.gather3(self.feat_2d_weights, index, 0)
+        else:
+            query_norm_gamma = self.query_norm_gammas
+            query_norm_beta = self.query_norm_betas
+            feat_2d_norm_gamma = self.feat_2d_norm_gammas
+            feat_2d_norm_beta = self.feat_2d_norm_betas
+            feat_2d_weight = self.feat_2d_weights
+
+        input_bias = 1e9 * (msa_mask - 1.0)
+        input_bias = P.ExpandDims()(P.ExpandDims()(input_bias, 1), 2)
+        if not self.is_extra_msa:
+            msa_act += self.add_interface(interface_mask, msa_act, index=index)
+
+        msa_act = self.masked_layer_norm(msa_act, query_norm_gamma, query_norm_beta, mask=norm_msa_mask)
+        pair_act = self.masked_layer_norm(pair_act, feat_2d_norm_gamma, feat_2d_norm_beta, mask=norm_pair_mask)
+        nonbatched_bias = self.batch_matmul(pair_act, feat_2d_weight)
+        if not self.is_extra_msa:
+            nonbatched_bias += self.process_sbr(sbr_act, sbr_mask, index=index)
+        nonbatched_bias = P.Transpose()(nonbatched_bias, (2, 0, 1))
+
+        batched_inputs = (msa_act, input_bias)
+        if res_idx is not None:
+            nonbatched_inputs = (nonbatched_bias, res_idx)
+        else:
+            nonbatched_inputs = (index, nonbatched_bias)
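+        # Note on the memory-reduction scheme used below: MemoryReduceCell is expected to
+        # split the batched inputs (msa_act and its additive bias) into self.slice_num
+        # chunks, run self._compute on each chunk, and concatenate the results, which
+        # lowers peak activation memory at the cost of extra kernel launches.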
+        msa_act = self.memory_reduce(self._compute, batched_inputs, nonbatched_inputs)
+        return msa_act
+
+    def _init_parameter(self):
+        '''init parameter'''
+        if self.batch_size:
+            self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
+            self.feat_2d_norm_gammas = Parameter(
+                Tensor(np.zeros([self.batch_size, self.pair_act_dim]), mstype.float32))
+            self.feat_2d_norm_betas = Parameter(
+                Tensor(np.zeros([self.batch_size, self.pair_act_dim]), mstype.float32))
+            self.feat_2d_weights = Parameter(
+                Tensor(np.zeros([self.batch_size, self.num_head, self.pair_act_dim]), mstype.float32))
+        else:
+            self.query_norm_gammas = Parameter(Tensor(np.ones([self.msa_act_dim]), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros([self.msa_act_dim]), mstype.float32))
+            self.feat_2d_norm_gammas = Parameter(Tensor(np.ones([self.pair_act_dim]), mstype.float32))
+            self.feat_2d_norm_betas = Parameter(Tensor(np.zeros([self.pair_act_dim]), mstype.float32))
+            self.feat_2d_weights = Parameter(
+                Tensor(np.random.normal(scale=1 / np.sqrt(self.pair_act_dim), size=[self.num_head, self.pair_act_dim]),
+                       mstype.float32))
+
+    def _compute(self, msa_act, mask, index, nonbatched_bias):
+        """
+        compute.
+
+        Args:
+            msa_act (Tensor): Tensor of msa_act.
+            mask (Tensor): The mask for the MSA row attention matrix.
+            index (Tensor): The index of the while loop, only used in case of while control flow. Default: None.
+            nonbatched_bias (Tensor): Tensor of the non-batched bias matrix.
+
+        Outputs:
+            - **msa_act** (Tensor) - Tensor, the float tensor of the msa_act of the attention layer.
+        """
+        msa_act = self.attn_mod(msa_act, msa_act, mask, index, nonbatched_bias)
+        return msa_act
+
+
+class MSAColumnAttention(nn.Cell):
+    """
+    MSA column-wise gated self attention.
+    The column-wise attention lets the elements that belong to the same target residue exchange information.
+
+    Reference:
+        `Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention"
+        `_.
+
+    Args:
+        num_head (int): The number of heads.
+        key_dim (int): The dimension of the input.
+        gating (bool): Whether the attention is gated.
+        msa_act_dim (int): The dimension of the msa_act, the intermediate variable after MSA retrieving
+            in AlphaFold.
+        batch_size (int): The batch size of parameters in MSAColumnAttention, used in while control flow.
+            Default: None.
+        slice_num (int): The number of slices to be made to reduce memory. Default: 0.
+
+    Inputs:
+        - **msa_act** (Tensor) - Tensor of msa_act, the intermediate variable after MSA retrieving
+          in AlphaFold, shape :math:`[N_{seqs}, N_{res}, C_m]` .
+        - **msa_mask** (Tensor) - The mask for the MSAColumnAttention matrix, shape :math:`[N_{seqs}, N_{res}]`.
+        - **index** (Tensor) - The index of the while loop, only used in case of while control flow. Default: None.
+
+    Outputs:
+        Tensor, the float tensor of the msa_act of the layer, shape :math:`[N_{seqs}, N_{res}, C_m]`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import MSAColumnAttention
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = MSAColumnAttention(num_head=8, key_dim=256, gating=True,
+        ...
msa_act_dim=256, device_num=1, batch_size=1, slice_num=0)
+        >>> msa_act = Tensor(np.ones((512, 256, 256)), mstype.float32)
+        >>> msa_mask = Tensor(np.ones((512, 256)), mstype.float32)
+        >>> index = Tensor(0, mstype.int32)
+        >>> attn_out = model(msa_act, msa_mask, index)
+        >>> print(attn_out.shape)
+        (512, 256, 256)
+    """
+
+    def __init__(self, num_head, key_dim, gating, msa_act_dim, device_num, batch_size=None, slice_num=0):
+        super(MSAColumnAttention, self).__init__()
+        self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5).shard(((1, device_num, 1), (1,), (1,)))
+        self.attn_mod = Attention(num_head, key_dim, gating, msa_act_dim, msa_act_dim, msa_act_dim, device_num, batch_size)
+        self.batch_size = batch_size
+        self.slice_num = slice_num
+        self.msa_act_dim = msa_act_dim
+        self.idx = Tensor(0, mstype.int32)
+        self._init_parameter()
+        self.gather2 = P.Gather().shard(((1, 1), ()))
+
+    def construct(self, msa_act, msa_mask, index=None):
+        '''construct'''
+        if self.batch_size:
+            query_norm_gamma = self.gather2(self.query_norm_gammas, index, 0)
+            query_norm_beta = self.gather2(self.query_norm_betas, index, 0)
+        else:
+            query_norm_gamma = self.query_norm_gammas
+            query_norm_beta = self.query_norm_betas
+        msa_act = P.Transpose()(msa_act, (1, 0, 2))
+        msa_mask = P.Transpose()(msa_mask, (1, 0))
+
+        input_mask = 1e9 * (msa_mask - 1.)
+        input_mask = P.ExpandDims()(P.ExpandDims()(input_mask, 1), 2)
+        msa_act, _, _ = self.query_norm(msa_act, query_norm_gamma, query_norm_beta)
+        batched_inputs = (msa_act, input_mask)
+        nonbatched_inputs = (index,)
+        msa_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num)
+        msa_act = P.Transpose()(msa_act, (1, 0, 2))
+        return msa_act
+
+    def _init_parameter(self):
+        if self.batch_size:
+            self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
+        else:
+            self.query_norm_gammas = Parameter(Tensor(np.ones([self.msa_act_dim]), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros([self.msa_act_dim]), mstype.float32))
+
+    def _compute(self, msa_act, input_mask, index):
+        '''compute'''
+        msa_act = self.attn_mod(msa_act, msa_act, input_mask, index)
+        return msa_act
+
+
+class MSAColumnGlobalAttention(nn.Cell):
+    r"""
+    MSA column global attention. Transposes the MSA information along the sequence and residue axes,
+    then uses `GlobalAttention
+    ` to
+    attend across the input sequences without modeling the relationship between residues within a
+    sequence. Compared with MSAColumnAttention, it uses GlobalAttention to handle longer input
+    sequences.
+
+    Reference:
+        `Jumper et al. (2021) Suppl. Alg. 19 'MSAColumnGlobalAttention'
+        `_.
+
+    Args:
+        num_head (int): The number of attention heads.
+        gating (bool): Whether the attention is gated.
+        msa_act_dim (int): The dimension of the msa_act.
+        batch_size (int): The batch size of parameters in MSAColumnGlobalAttention, used
+            in while control flow. Default: None.
+        slice_num (int): The number of slices to be made to reduce memory. Default: 0.
+
+    Inputs:
+        - **msa_act** (Tensor) - Tensor of msa_act with shape :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .
+        - **msa_mask** (Tensor) - The mask for the msa_act matrix with shape :math:`(N_{seqs}, N_{res})` .
+        - **index** (Tensor) - The index of the while loop, only used in case of while control flow. Default: None.
+
+    Outputs:
+        Tensor, the float tensor of the msa_act of the layer with shape :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import MSAColumnGlobalAttention
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64, device_num=1, batch_size=None)
+        >>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
+        >>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
+        >>> index = None
+        >>> msa_out = model(msa_act, msa_mask, index)
+        >>> print(msa_out.shape)
+        (4, 256, 64)
+    """
+
+    def __init__(self, num_head, gating, msa_act_dim, device_num, batch_size=None, slice_num=0):
+        super(MSAColumnGlobalAttention, self).__init__()
+        self.attn_mod = GlobalAttention(num_head, gating, msa_act_dim, msa_act_dim, device_num, batch_size)
+        # mask is None here, so MaskedLayerNorm() is not used
+        self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+        self.batch_size = batch_size
+        self.slice_num = slice_num
+        self.msa_act_dim = msa_act_dim
+        self.idx = Tensor(0, mstype.int32)
+        self.trans2 = P.Transpose().shard(((1, device_num, 1),))
+        self.memory_reduce = MemoryReduceCell(self.slice_num, device_num,
+                                              strategy=[((1, device_num, 1), (1,)), ((1, device_num, 1), (1,))])
+        self.trans = P.Transpose().shard(((1, device_num),))
+        self._init_parameter()
+        self.gather2 = P.Gather().shard(((1, 1), ()))
+
+    def construct(self, msa_act, msa_mask, index=None):
+        '''construct'''
+        if self.batch_size:
+            query_norm_gamma = self.gather2(self.query_norm_gammas, index, 0)
+            query_norm_beta = self.gather2(self.query_norm_betas, index, 0)
+        else:
+            query_norm_gamma = self.query_norm_gammas
+            query_norm_beta = self.query_norm_betas
+        msa_act = self.trans2(msa_act, (1, 0, 2))
+        msa_mask = self.trans(msa_mask, (1, 0))
+
+        # The additive input_mask is not needed here (noted 2025-02-08): GlobalAttention
+        # consumes the per-position msa_mask directly instead of a -1e9 attention bias.
+        # input_mask = 1e9 * (msa_mask - 1.)
+ # input_mask = P.ExpandDims()(P.ExpandDims()(input_mask, 1), 2) + + msa_act, _, _ = self.query_norm(msa_act, + query_norm_gamma, + query_norm_beta) + # msa_act = self.query_norm(msa_act, + # query_norm_gamma, + # query_norm_beta) + msa_mask = P.ExpandDims()(msa_mask, -1) + batched_inputs = (msa_act, msa_mask) + nonbatched_inputs = (index,) + msa_act = self.memory_reduce(self._compute, batched_inputs, nonbatched_inputs) + # msa_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num) + # msa_act = P.Transpose()(msa_act, (1, 0, 2)) + msa_act = self.trans2(msa_act, (1, 0, 2)) + return msa_act + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32)) + self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32)) + else: + self.query_norm_gammas = Parameter(Tensor(np.ones((self.msa_act_dim)), mstype.float32)) + self.query_norm_betas = Parameter(Tensor(np.zeros((self.msa_act_dim)), mstype.float32)) + + def _compute(self, msa_act, msa_mask, index): + """ + compute. + + Args: + msa_act (Tensor): Tensor of msa_act. + msa_mask (Tensor): The mask for msa_act matrix. + index (Tensor): The index of while loop, only used in case of while + control flow. Default: None + + Outputs: + - **msa_act** (Tensor)- Tensor, the float tensor of the msa_act of the attention layer. + """ + msa_act = self.attn_mod(msa_act, msa_act, msa_mask, index) + return msa_act \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/sbr.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/sbr.py new file mode 100644 index 0000000000000000000000000000000000000000..3882053d2535759bd1d70ac5597ac6dabc05c054 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/sbr.py @@ -0,0 +1,91 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Soft blurred restraints""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore import Parameter +from mindspore.ops import operations as P +# from .mask import MaskedLayerNorm + +def lecun_normal(dim_in, shape): + stddev = 1./np.sqrt(dim_in) + return np.random.normal(loc=0, scale=stddev, size=shape) + +class ProcessSBR(nn.Cell): + '''add inter-residue soft blurred restraints into pair representation''' + + def __init__(self, input_dim, output_dim, batch_size=None): + super(ProcessSBR, self).__init__() + self.matmul = P.MatMul(transpose_b=True) + self.input_dim = input_dim + self.output_dim = output_dim + self.batch_size = batch_size + + self.relu = nn.ReLU() + self.layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) + self._init_parameter() + self.gather3 = P.Gather().shard(((1, 1, 1), ())) + self.gather2 = P.Gather().shard(((1, 1), ())) + + def construct(self, act, mask=None, index=None, useperm=False): + '''Compute linear''' + linear_bias=None + if self.batch_size: + input_layer_norm_gamma = self.gather2(self.input_layer_norm_gammas, index, 0) + input_layer_norm_beta = self.gather2(self.input_layer_norm_betas, index, 0) + linear_weight = self.gather3(self.linear_weights, index, 0) + linear_bias = self.gather2(self.linear_biases, index, 0) + else: + input_layer_norm_gamma = self.input_layer_norm_gammas + input_layer_norm_beta = self.input_layer_norm_betas + linear_weight = self.linear_weights + linear_bias = self.linear_biases + act, _, _ = self.layer_norm(act, input_layer_norm_gamma, input_layer_norm_beta) + + act_shape = P.Shape()(act) + if len(act_shape) != 2: + act = P.Reshape()(act, (-1, act_shape[-1])) + act = P.BiasAdd()(self.matmul(act, linear_weight), linear_bias) + + + act = P.Reshape()(act, act_shape[:-1]+(-1,)) + if mask is not None: + if not useperm: + act *= P.ExpandDims()(mask, -1) + else: + act = P.Transpose()(act, (2, 0, 1)) + act *= mask[None, :, :] + act = P.Transpose()(act, (1, 2, 0)) + return act + + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.input_layer_norm_gammas = Parameter( + Tensor(np.ones((self.batch_size, self.input_dim)), mstype.float32)) + self.input_layer_norm_betas = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32)) + self.linear_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.output_dim, self.input_dim)), mstype.float32)) + self.linear_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.output_dim)), mstype.float32)) + else: + self.input_layer_norm_gammas = Parameter(Tensor(np.ones((self.input_dim)), mstype.float32)) + self.input_layer_norm_betas = Parameter(Tensor(np.zeros((self.input_dim)), mstype.float32)) + self.linear_weights = Parameter(Tensor(np.zeros((self.output_dim, self.input_dim)), mstype.float32)) + self.linear_biases = Parameter(Tensor(np.zeros((self.output_dim, )), mstype.float32)) \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/transition.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/transition.py new file mode 100644 index 0000000000000000000000000000000000000000..7593ba64ca57db83602a5ff1f8651d89c2db1ec2 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/transition.py @@ -0,0 +1,157 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking 
University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Transition"""
+import numpy as np
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+from mindspore.common.tensor import Tensor
+from mindspore import Parameter
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common.initializer import initializer
+from .initializer import lecun_init
+from .mask import MaskedLayerNormParallel  # MaskedLayerNorm
+from ..common.utils import _memory_reduce, MemoryReduceCell
+
+
+class Transition(nn.Cell):
+    r"""
+    This is a 2-layer MLP in which the intermediate layer expands the number of channels
+    of the input by a factor of num_intermediate_factor, applies a ReLU, and projects
+    back to the input dimension.
+
+    .. math::
+        Transition(\mathbf{act}) = Linear(ReLU(Linear(\mathbf{act})))
+
+    Args:
+        num_intermediate_factor (float): The expansion factor of the intermediate output
+            channels compared to the input.
+        input_dim (int): The number of channels of the input.
+        batch_size (int): The batch size of parameters in Transition,
+            used in while control flow. Default: None.
+        slice_num (int): The slice num used in the transition layer
+            when the memory overflows. Default: 0.
+
+    Inputs:
+        - **act** (Tensor) - The input with channels equal to input_dim, shape is (..., input_dim).
+        - **index** (Tensor) - The index of the while loop, only used in case of while control
+          flow. Default: None.
+        - **mask** (Tensor) - The mask of act used when doing layernorm, with shape :math:`(32, input\_dim)`.
+          Default: None.
+
+    Outputs:
+        Tensor, the float tensor of the output of the layer with shape (..., input_dim).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import Transition
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = Transition(num_intermediate_factor=4, input_dim=128, device_num=1)
+        >>> inputs = Tensor(np.ones((32, 128, 128)), mstype.float32)
+        >>> output = model(inputs)
+        >>> print(output.shape)
+        (32, 128, 128)
+    """
+
+    def __init__(self, num_intermediate_factor, input_dim, device_num, batch_size=None, slice_num=0):
+        super(Transition, self).__init__()
+        self.matmul = P.MatMul(transpose_b=True)
+        self.input_dim = input_dim
+        self.num_intermediate = int(input_dim * num_intermediate_factor)
+        self.batch_size = batch_size
+        self.slice_num = slice_num
+        self.relu = nn.ReLU()
+        self.idx = Tensor(0, mstype.int32)
+        self.masked_layer_norm = MaskedLayerNormParallel(device_num)
+        self._init_parameter()
+        self.gather3 = P.Gather().shard(((1, 1, 1), ()))
+        self.gather2 = P.Gather().shard(((1, 1), ()))
+        self.memory_reduce = MemoryReduceCell(self.slice_num, device_num, strategy=[((1, device_num, 1), (1,)),])
+
+    def construct(self, act, index=None, mask=None):
+        '''Compute transition'''
+        if self.batch_size:
+            input_layer_norm_gamma = self.gather2(self.input_layer_norm_gammas, index, 0)
+            input_layer_norm_beta = self.gather2(self.input_layer_norm_betas, index, 0)
+            transition1_weight = self.gather3(self.transition1_weights, index, 0)
+            transition1_bias = self.gather2(self.transition1_biases, index, 0)
+            transition2_weight = self.gather3(self.transition2_weights, index, 0)
+            transition2_bias = self.gather2(self.transition2_biases, index, 0)
+        else:
+            input_layer_norm_gamma = self.input_layer_norm_gammas
+            input_layer_norm_beta = self.input_layer_norm_betas
+            transition1_weight = self.transition1_weights
+            transition1_bias = self.transition1_biases
+            transition2_weight = self.transition2_weights
+            transition2_bias = self.transition2_biases
+
+        act = self.masked_layer_norm(act, input_layer_norm_gamma, input_layer_norm_beta, mask=mask)
+        batched_inputs = (act,)
+        nonbatched_inputs = (transition1_weight, transition1_bias, transition2_weight, transition2_bias)
+        act = self.memory_reduce(self._compute, batched_inputs, nonbatched_inputs)
+        return act
+
+    def _init_parameter(self):
+        '''init parameter'''
+        if self.batch_size:
+            self.input_layer_norm_gammas = Parameter(
+                Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32))
+            self.input_layer_norm_betas = Parameter(
+                Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32))
+            self.transition1_weights = Parameter(
+                Tensor(np.zeros((self.batch_size, self.num_intermediate, self.input_dim)), mstype.float32))
+            self.transition1_biases = Parameter(
+                Tensor(np.zeros((self.batch_size, self.num_intermediate)), mstype.float32))
+            self.transition2_weights = Parameter(
+                Tensor(np.zeros((self.batch_size, self.input_dim,
self.num_intermediate)), mstype.float32)) + self.transition2_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32)) + else: + self.input_layer_norm_gammas = Parameter(Tensor(np.ones((self.input_dim)), mstype.float32)) + self.input_layer_norm_betas = Parameter(Tensor(np.zeros((self.input_dim)), mstype.float32)) + self.transition1_weights = Parameter(initializer(lecun_init(self.input_dim, initializer_name='relu'), + [self.num_intermediate, self.input_dim])) + self.transition1_biases = Parameter(Tensor(np.zeros((self.num_intermediate)), mstype.float32)) + self.transition2_weights = Parameter( + Tensor(np.zeros((self.input_dim, self.num_intermediate)), mstype.float32)) + self.transition2_biases = Parameter(Tensor(np.zeros((self.input_dim)), mstype.float32)) + + def _compute(self, act, transition1_weight, transition1_bias, transition2_weight, transition2_bias): + '''compute transition.''' + + act_shape = P.Shape()(act) + if len(act_shape) != 2: + act = P.Reshape()(act, (-1, act_shape[-1])) + act = self.relu(P.BiasAdd()(self.matmul(act, transition1_weight), transition1_bias)) + act = P.BiasAdd()(self.matmul(act, transition2_weight), transition2_bias) + act = P.Reshape()(act, act_shape) + # act = self.relu(self.biasadd(self.batch_matmul(act, transition1_weight), transition1_bias)) + # act = self.biasadd(self.batch_matmul(act, transition2_weight), transition2_bias) + return act \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/cell/triangle.py b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/triangle.py new file mode 100644 index 0000000000000000000000000000000000000000..95ee1203f13d156bd92bc9e8726bd6ebada4acb0 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/cell/triangle.py @@ -0,0 +1,681 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Triangle""" +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Parameter +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from .basic import Attention, Attention2 +from .initializer import lecun_init +from .mask import MaskedLayerNorm +from ..common.utils import _memory_reduce, MemoryReduceCell + + +class TriangleAttention(nn.Cell): + r""" + Triangle attention. for the detailed implementation process, refer to + `TriangleAttention `_. + + The information between the amino acid pair is integrated through the information of three edges ij, ik, jk, + which is divided into three parts: projection, self-attention and output. 
First, the amino acid pair representation is projected
+    to obtain q, k and v; then the classic multi-head self-attention mechanism adds the relationship
+    among the i, j, k triangle edges; finally the result is output.
+
+    Args:
+        orientation (str): The orientation of the triangle attention, deciding which axis acts as the
+            starting and ending edge of the self-attention ('per_row' or 'per_column').
+        num_head (int): The number of attention heads.
+        key_dim (int): The dimension of the hidden layer.
+        gating (bool): Whether the attention is gated.
+        layer_norm_dim (int): The dimension of the layer_norm.
+        batch_size (int): The batch size of triangle attention. Default: None.
+        slice_num (int): The number of slices to be made to reduce memory. Default: 0.
+
+    Inputs:
+        - **pair_act** (Tensor) - Tensor of pair_act with shape :math:`(N_{res}, N_{res}, layer\_norm\_dim)`.
+        - **pair_mask** (Tensor) - The mask for the TriangleAttention matrix with shape :math:`(N_{res}, N_{res})`.
+        - **index** (Tensor) - The index of the while loop, only used in case of while control flow. Default: None.
+        - **mask** (Tensor) - The mask of pair_act used when doing layernorm, with shape :math:`(N_{res}, N_{res})`.
+          Default: None.
+
+    Outputs:
+        Tensor, the float tensor of the pair_act of the layer with shape :math:`(N_{res}, N_{res}, layer\_norm\_dim)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import TriangleAttention
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = TriangleAttention(orientation="per_row", num_head=4, key_dim=64, gating=True,
+        ...                           layer_norm_dim=64, device_num=1)
+        >>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32)
+        >>> input_1 = Tensor(np.ones((256, 256)), mstype.float32)
+        >>> out = model(input_0, input_1, index=0)
+        >>> print(out.shape)
+        (256, 256, 64)
+    """
+
+    def __init__(self, orientation, num_head, key_dim, gating, layer_norm_dim, device_num, batch_size=None, slice_num=0):
+        super(TriangleAttention, self).__init__()
+        self.num_head = num_head
+        self.orientation = orientation
+        self.orientation_is_per_column = (self.orientation == 'per_column')
+        self.init_factor = Tensor(1. 
/ np.sqrt(layer_norm_dim), mstype.float32) + self.matmul = P.MatMul(transpose_b=True) + self.slice_num = slice_num + # concat = [] + # for i in range(slice_num): + # concat.append((1, device_num, 1)) + # self.memory_reduce = MemoryReduceCell(self.slice_num, device_num, strategy=[((1, device_num, 1), (1,)), None]) + # self.memory_reduce.concat.shard(tuple(concat)) + if self.orientation_is_per_column: + + # self.slice_num = slice_num + concat = [] + for i in range(slice_num): + concat.append((device_num, 1, 1)) + self.memory_reduce = MemoryReduceCell(self.slice_num, device_num, strategy=[((device_num, 1, 1), (1,)), ((1, 1, 1), (1,))], dim=1, gather_dim=1) + self.memory_reduce.concat.shard(tuple(concat)) + + + # self.memory_reduce = MemoryReduceCell(self.slice_num, device_num, strategy=[((1, device_num, 1), (1,)), ((1, 1, 1, 1), (1,))]) + # self.memory_reduce.concat.shard(tuple(concat)) + + self.layernorm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5).shard(((device_num, 1, 1), (1,), (1,))) + self.batchmatmul_b2 = P.BatchMatMul(transpose_b=True).shard(((device_num, 1, 1), (1, 1))) + self.mul = P.Mul() + self.sub = P.Sub() + self.expand = P.ExpandDims() + self.expand2 = P.ExpandDims() + self.trans3 = P.Transpose().shard(((device_num, 1, 1),)) + self.attn_mod = Attention2(num_head, key_dim, gating, layer_norm_dim, layer_norm_dim, layer_norm_dim, + device_num, batch_size) + else: + + self.memory_reduce = MemoryReduceCell(self.slice_num, device_num, strategy=[((1, device_num, 1), (1,)), ((1, 1, 1, 1), (1,))]) + # self.memory_reduce.concat.shard(tuple(concat)) + + self.layernorm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5).shard(((1, device_num, 1), (1,), (1,))) + self.batchmatmul_b2 = P.BatchMatMul(transpose_b=True).shard(((1, device_num, 1), (1, 1))) + self.mul = P.Mul().shard(((), (1, device_num))) + self.sub = P.Sub().shard(((1, device_num), ())) + self.expand = P.ExpandDims().shard(((1, device_num),)) + self.expand2 = P.ExpandDims().shard(((1, 1, device_num),)) + self.trans3 = P.Transpose().shard(((1, device_num, 1),)) + self.attn_mod = Attention(num_head, key_dim, gating, layer_norm_dim, layer_norm_dim, layer_norm_dim, + device_num, batch_size) + + self.batchmatmul_b = P.BatchMatMul(transpose_b=True) + # self.attn_mod = Attention(num_head, key_dim, gating, layer_norm_dim, layer_norm_dim, layer_norm_dim, + # device_num, batch_size) + self.batch_size = batch_size + self.slice_num = slice_num + self.layer_norm_dim = layer_norm_dim + self.idx = Tensor(0, mstype.int32) + self.masked_layer_norm = MaskedLayerNorm() + self._init_parameter() + self.gather3 = P.Gather().shard(((1, 1, 1), ())) + self.gather2 = P.Gather().shard(((1, 1), ())) + self.trans = P.Transpose().shard(((1, device_num),)) + self.trans2 = P.Transpose().shard(((1, device_num, 1),)) + + + def construct(self, pair_act, pair_mask, index=None, mask=None): + '''construct''' + if self.batch_size: + query_norm_gamma = self.gather2(self.query_norm_gammas, index, 0) + query_norm_beta = self.gather2(self.query_norm_betas, index, 0) + feat_2d_weight = self.gather3(self.feat_2d_weights, index, 0) + else: + query_norm_gamma = self.query_norm_gammas + query_norm_beta = self.query_norm_betas + feat_2d_weight = self.feat_2d_weights + if self.orientation_is_per_column: + # pair_act = P.Transpose()(pair_act, (1, 0, 2)) + pair_act = self.trans2(pair_act, (1, 0, 2)) + # pair_mask = P.Transpose()(pair_mask, (1, 0)) + pair_mask = self.trans(pair_mask, (1, 0)) + + + # Fix Bug + # pair_act = 
self.masked_layer_norm(pair_act, query_norm_gamma, query_norm_beta, mask)
+
+        pair_act, _, _ = self.layernorm(pair_act,
+                                        query_norm_gamma,
+                                        query_norm_beta)
+
+        q, k, _ = pair_act.shape
+        nonbatched_bias = self.batchmatmul_b2(pair_act, feat_2d_weight)
+        nonbatched_bias = self.trans3(P.Reshape()(nonbatched_bias, (q, k, -1)), (2, 0, 1))
+
+        pair_mask = self.mul(1e9, self.sub(pair_mask, 1.))
+        input_mask = self.expand2(self.expand(pair_mask, 1), 2)
+
+        if self.orientation_is_per_column:
+            batched_inputs = (pair_act, nonbatched_bias)
+            nonbatched_inputs = (input_mask, pair_act, index)
+            pair_act = self.memory_reduce(self._compute_column, batched_inputs, nonbatched_inputs)
+            pair_act = self.trans2(pair_act, (1, 0, 2))
+        else:
+            batched_inputs = (pair_act, input_mask)
+            nonbatched_inputs = (index, nonbatched_bias)
+            pair_act = self.memory_reduce(self._compute, batched_inputs, nonbatched_inputs)
+        return pair_act
+
+    def _init_parameter(self):
+        '''init parameter'''
+        if self.batch_size:
+            self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+            self.feat_2d_weights = Parameter(
+                Tensor(np.zeros((self.batch_size, self.num_head, self.layer_norm_dim)), mstype.float32))
+        else:
+            self.query_norm_gammas = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32))
+            self.feat_2d_weights = Parameter(Tensor(
+                np.random.normal(scale=1 / np.sqrt(self.layer_norm_dim), size=(self.num_head, self.layer_norm_dim)),
+                mstype.float32))
+
+    def _compute(self, pair_act, input_mask, index, nonbatched_bias):
+        '''compute triangle'''
+        pair_act = self.attn_mod(pair_act, pair_act, input_mask, index, nonbatched_bias)
+        return pair_act
+
+    def _compute_column(self, pair_act, nonbatched_bias, input_mask, pair_act_kv, index):
+        '''compute triangle'''
+        pair_act = self.attn_mod(pair_act, pair_act_kv, input_mask, index, nonbatched_bias)
+        return pair_act
+
+
+class TriangleMultiplication(nn.Cell):
+    r"""
+    Triangle multiplication layer. For the detailed implementation process, refer to
+    `TriangleMultiplication `_.
+
+    The information between the amino acid pair is integrated through the information of the three
+    edges ij, ik, jk, and the result of the dot product between ik and jk is added to the ij edge.
+
+    Args:
+        num_intermediate_channel (float): The number of intermediate channels.
+        equation (str): The equation used in the triangle multiplication layer; the edge update forms
+            corresponding to 'outgoing' and 'incoming' are :math:`ikc,jkc->ijc` and :math:`kjc,kic->ijc`
+            respectively.
+        layer_norm_dim (int): The last dimension length of the layer norm.
+        batch_size (int): The batch size of parameters in triangle multiplication. Default: None.
+
+    Inputs:
+        - **pair_act** (Tensor) - Tensor of pair_act with shape :math:`(N_{res}, N_{res}, layer\_norm\_dim)`.
+        - **pair_mask** (Tensor) - The mask of pair_act with shape :math:`(N_{res}, N_{res})`.
+        - **index** (Tensor) - The index of the while loop, only used in case of while control flow.
+
+    Outputs:
+        Tensor, the float tensor of the pair_act of the layer with shape :math:`(N_{res}, N_{res}, layer\_norm\_dim)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import TriangleMultiplication
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = TriangleMultiplication(num_intermediate_channel=64,
+        ...                                equation="ikc,jkc->ijc", layer_norm_dim=64, device_num=1, batch_size=0)
+        >>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32)
+        >>> input_1 = Tensor(np.ones((256, 256)), mstype.float32)
+        >>> out = model(input_0, input_1, index=0)
+        >>> print(out.shape)
+        (256, 256, 64)
+    """
+
+    def __init__(self, num_intermediate_channel, equation, layer_norm_dim, device_num, batch_size=None):
+        super(TriangleMultiplication, self).__init__()
+        self.num_intermediate_channel = num_intermediate_channel
+        self.equation = equation
+        self.layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5).shard(((1, device_num, 1), (1,), (1,)))
+        self.matmul = P.MatMul(transpose_b=True)
+        self.sigmoid = nn.Sigmoid()
+        self.sigmoid.sigmoid.shard(((1, device_num, 1),))
+        self.batch_matmul_trans_b1 = P.BatchMatMul(transpose_b=True).shard(((1, 1, device_num), (1, 1, device_num)))
+        self.add = P.Add().shard(((1, device_num, 1), (1,)))
+        self.batch_matmul = P.BatchMatMul(transpose_b=True).shard(((1, device_num, 1), (1, 1)))
+        self.mul = P.Mul().shard(((1, device_num, 1), (1, device_num, 1)))
+        self.batch_matmul_trans_b2 = P.BatchMatMul(transpose_b=True).shard(((1, device_num, 1), (1, 1, 1)))
+        equation = ["ikc,jkc->ijc", "kjc,kic->ijc"]
+        if self.equation not in equation:
+            print("TriangleMultiplication: unsupported equation", self.equation)
+        if self.equation == "ikc,jkc->ijc":
+            self.equation = True
+            concat = []
+            for i in range(device_num):
+                concat.append((1, 1, device_num))
+            self.memory_reduce = MemoryReduceCell(device_num, device_num, strategy=[((1, 1, device_num), (1,)), ((1, 1, device_num), (1,))])
+            self.memory_reduce.concat.shard(tuple(concat))
+            self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True).shard(((1, 1, device_num), (1, 1, device_num)))
+        elif self.equation == "kjc,kic->ijc":
+            self.equation = False
+            self.memory_reduce = MemoryReduceCell(device_num, device_num, strategy=[((1, device_num, 1), (1,)), ((1, device_num, 1), (1,))])
+            concat = []
+            for i in range(device_num):
+                concat.append((1, device_num, 1))
+            self.memory_reduce.concat.shard(tuple(concat))
+            self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True).shard(((1, device_num, 1), (1, 1, 1)))
+        else:
+            self.equation = None
+        self.batch_size = batch_size
+        self.layer_norm_dim = layer_norm_dim
+        self._init_parameter()
+        self.gather3 = P.Gather().shard(((1, 1, 1), ()))
+        self.gather2 = P.Gather().shard(((1, 1), ()))
+        self.trans2 = P.Transpose().shard(((1, device_num, 1),))
+        self.trans3 = P.Transpose().shard(((1, 1, device_num),))
+
+    def compute(self, left_proj_act_tmp, right_proj_act_tmp, index):
+        act = self.batch_matmul_trans_b(left_proj_act_tmp, right_proj_act_tmp)
+        return act
+
+    def 
construct(self, act, mask, index=None): + r""" + Builds triangle multiplication module. + + Args: + act(Tensor): Pair activations. Data type is float. + mask(Tensor): Pair mask. Data type is float. + index(int): The index of the parameter set to use when batch_size is not None. + + Returns: + act(Tensor), with the same shape as the input act. + """ + + if self.batch_size: + layer_norm_input_gamma = self.gather2(self.layer_norm_input_gammas, index, 0) + layer_norm_input_beta = self.gather2(self.layer_norm_input_betas, index, 0) + left_projection_weight = self.gather3(self.left_projection_weights, index, 0) + left_projection_bias = self.gather2(self.left_projection_biases, index, 0) + right_projection_weight = self.gather3(self.right_projection_weights, index, 0) + right_projection_bias = self.gather2(self.right_projection_biases, index, 0) + left_gate_weight = self.gather3(self.left_gate_weights, index, 0) + left_gate_bias = self.gather2(self.left_gate_biases, index, 0) + right_gate_weight = self.gather3(self.right_gate_weights, index, 0) + right_gate_bias = self.gather2(self.right_gate_biases, index, 0) + center_layer_norm_gamma = self.gather2(self.center_layer_norm_gammas, index, 0) + center_layer_norm_beta = self.gather2(self.center_layer_norm_betas, index, 0) + output_projection_weight = self.gather3(self.output_projection_weights, index, 0) + output_projection_bias = self.gather2(self.output_projection_biases, index, 0) + gating_linear_weight = self.gather3(self.gating_linear_weights, index, 0) + gating_linear_bias = self.gather2(self.gating_linear_biases, index, 0) + else: + layer_norm_input_gamma = self.layer_norm_input_gammas + layer_norm_input_beta = self.layer_norm_input_betas + left_projection_weight = self.left_projection_weights + left_projection_bias = self.left_projection_biases + right_projection_weight = self.right_projection_weights + right_projection_bias = self.right_projection_biases + left_gate_weight = self.left_gate_weights + left_gate_bias = self.left_gate_biases + right_gate_weight = self.right_gate_weights + right_gate_bias = self.right_gate_biases + center_layer_norm_gamma = self.center_layer_norm_gammas + center_layer_norm_beta = self.center_layer_norm_betas + output_projection_weight = self.output_projection_weights + output_projection_bias = self.output_projection_biases + gating_linear_weight = self.gating_linear_weights + gating_linear_bias = self.gating_linear_biases + + mask = P.ExpandDims()(mask, -1) + # act = self.layer_norm(act, + # layer_norm_input_gamma, + # layer_norm_input_beta) + + act, _, _ = self.layer_norm(act, + layer_norm_input_gamma, + layer_norm_input_beta) + act_shape = P.Shape()(act) + # if len(act_shape) != 2: + # act = P.Reshape()(act, (-1, act_shape[-1])) + out_shape = act_shape[:-1] + (-1,) + input_act = act + # left_projection = P.BiasAdd()(self.matmul(act, left_projection_weight), left_projection_bias) + left_projection = self.add(self.batch_matmul(act, left_projection_weight), left_projection_bias) + + # left_gate_values = P.BiasAdd()(self.matmul(act, left_gate_weight), left_gate_bias) + left_gate_values = self.add(self.batch_matmul(act, left_gate_weight), left_gate_bias) + left_gate_values = self.sigmoid(left_gate_values) + + #
left_proj_act = left_projection * left_gate_values + left_proj_act = self.mul(left_projection, left_gate_values) + left_proj_act = P.Reshape()(left_proj_act, out_shape) + + # right_projection = P.BiasAdd()(self.matmul(act, right_projection_weight), right_projection_bias) + right_projection = self.add(self.batch_matmul(act, right_projection_weight), right_projection_bias) + # right_gate_values = P.BiasAdd()(self.matmul(act, right_gate_weight), right_gate_bias) + right_gate_values = self.add(self.batch_matmul(act, right_gate_weight), right_gate_bias) + right_gate_values = self.sigmoid(right_gate_values) + + # right_proj_act = mask * P.Reshape()(right_projection * right_gate_values, out_shape) + right_proj_act = self.mul(mask, P.Reshape()(self.mul(right_projection, right_gate_values), out_shape)) + if self.equation is not None: + if self.equation: + # left_proj_act_tmp = P.Transpose()(left_proj_act, (2, 0, 1)) + # right_proj_act_tmp = P.Transpose()(right_proj_act, (2, 0, 1)) + left_proj_act_tmp = self.trans2(left_proj_act, (2, 0, 1)) + right_proj_act_tmp = self.trans2(right_proj_act, (2, 0, 1)) + batched_inputs = (left_proj_act_tmp, right_proj_act_tmp,) + nonbatched_inputs = (right_proj_act_tmp,) + act = self.memory_reduce(self.compute, batched_inputs, nonbatched_inputs) + # act = self.batch_matmul_trans_b1(left_proj_act_tmp, right_proj_act_tmp) + # act = P.Transpose()(act, (1, 2, 0)) + act = self.trans3(act, (1, 2, 0)) + else: + left_proj_act_tmp = self.trans2(left_proj_act, (2, 1, 0)) + right_proj_act_tmp = self.trans2(right_proj_act, (2, 1, 0)) + batched_inputs = (left_proj_act_tmp, right_proj_act_tmp,) + nonbatched_inputs = (right_proj_act_tmp,) + act = self.memory_reduce(self.compute, batched_inputs, nonbatched_inputs) + # act = self.batch_matmul_trans_b2(left_proj_act_tmp, right_proj_act_tmp) + act = self.trans2(act, (2, 1, 0)) + act, _, _ = self.layer_norm(act, + center_layer_norm_gamma, + center_layer_norm_beta) + # if len(act_shape) != 2: + # act = P.Reshape()(act, (-1, act_shape[-1])) + + # act = P.BiasAdd()(self.matmul(act, output_projection_weight), output_projection_bias) + act = self.add(self.batch_matmul(act, output_projection_weight), output_projection_bias) + # gate_values = P.BiasAdd()(self.matmul(input_act, gating_linear_weight), gating_linear_bias) + gate_values = self.add(self.batch_matmul(input_act, gating_linear_weight), gating_linear_bias) + gate_values = self.sigmoid(gate_values) + + # act = P.Reshape()(act * gate_values, out_shape) + + act = P.Reshape()(self.mul(act, gate_values), out_shape) + return act + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.layer_norm_input_gammas = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + self.layer_norm_input_betas = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + self.left_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel, self.layer_norm_dim)), + mstype.float32)) +
self.left_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel)), mstype.float32)) + self.right_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel, self.layer_norm_dim)), + mstype.float32)) + self.right_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel)), mstype.float32)) + self.left_gate_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel, self.layer_norm_dim)), + mstype.float32)) + self.left_gate_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel)), mstype.float32)) + self.right_gate_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel, self.layer_norm_dim)), + mstype.float32)) + self.right_gate_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel)), mstype.float32)) + self.center_layer_norm_gammas = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + self.center_layer_norm_betas = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + self.output_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) + self.output_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + self.gating_linear_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) + self.gating_linear_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + else: + self.layer_norm_input_gammas = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32)) + self.layer_norm_input_betas = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32)) + self.left_projection_weights = Parameter(initializer(lecun_init(self.num_intermediate_channel), + [self.num_intermediate_channel, + self.layer_norm_dim])) + self.left_projection_biases = Parameter( + Tensor(np.zeros((self.num_intermediate_channel)), mstype.float32)) + self.right_projection_weights = Parameter(initializer(lecun_init(self.num_intermediate_channel), + [self.num_intermediate_channel, + self.layer_norm_dim])) + self.right_projection_biases = Parameter( + Tensor(np.zeros((self.num_intermediate_channel)), mstype.float32)) + self.left_gate_weights = Parameter( + Tensor(np.zeros((self.num_intermediate_channel, self.layer_norm_dim)), mstype.float32)) + self.left_gate_biases = Parameter(Tensor(np.ones((self.num_intermediate_channel)), mstype.float32)) + self.right_gate_weights = Parameter( + Tensor(np.zeros((self.num_intermediate_channel, self.layer_norm_dim)), mstype.float32)) + self.right_gate_biases = Parameter(Tensor(np.ones((self.num_intermediate_channel)), mstype.float32)) + self.center_layer_norm_gammas = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32)) + self.center_layer_norm_betas = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32)) + self.output_projection_weights = Parameter( + Tensor(np.zeros((self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) + self.output_projection_biases = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32)) + self.gating_linear_weights = Parameter( + Tensor(np.zeros((self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) + self.gating_linear_biases = Parameter(Tensor(np.ones((self.layer_norm_dim)), 
mstype.float32)) + + +class OuterProductMean(nn.Cell): + r""" + Computing the correlation of the input tensor along its second dimension; the computed correlation + can be used to update the correlation features (e.g. the pair representation). + + .. math:: + OuterProductMean(\mathbf{act}) = Linear(flatten(mean(\mathbf{act}\otimes\mathbf{act}))) + + Args: + num_outer_channel (int): The last dimension size of the intermediate layer in OuterProductMean. + act_dim (int): The last dimension size of the input act. + num_output_channel (int): The last dimension size of output. + device_num (int): The number of devices used for operator-level parallelism (sharding). + batch_size(int): The batch size of parameters in OuterProductMean, + used in while control flow. Default: None. + slice_num (int): The number of slices used in the OuterProductMean layer + when memory overflows. Default: 0. + + Inputs: + - **act** (Tensor) - The input tensor with shape :math:`(dim_1, dim_2, act\_dim)`. + - **mask** (Tensor) - The mask for OuterProductMean with shape :math:`(dim_1, dim_2)`. + - **mask_norm** (Tensor) - Squared L2-norm along the first dimension of **mask**, + pre-computed to avoid re-computing, its shape is :math:`(dim_2, dim_2, 1)`. + - **index** (Tensor) - The index of while loop, only used in case of while control + flow. Default: None. + + Outputs: + Tensor, the float tensor of the output of OuterProductMean layer with + shape :math:`(dim_2, dim_2, num\_output\_channel)`. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import OuterProductMean + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> from mindspore.ops import operations as P + >>> model = OuterProductMean(num_outer_channel=32, act_dim=128, num_output_channel=256, device_num=1) + >>> act = Tensor(np.ones((32, 64, 128)), mstype.float32) + >>> mask = Tensor(np.ones((32, 64)), mstype.float32) + >>> mask_norm = P.ExpandDims()(P.MatMul(transpose_a=True)(mask, mask), -1) + >>> output = model(act, mask, mask_norm) + >>> print(output.shape) + (64, 64, 256) + """ + + def __init__(self, num_outer_channel, act_dim, num_output_channel, device_num, batch_size=None, slice_num=0): + super(OuterProductMean, self).__init__() + self.num_output_channel = num_output_channel + self.num_outer_channel = num_outer_channel + # self.layer_norm_input = MaskedLayerNorm() + self.layer_norm_input = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5).shard(((1, device_num, 1), (1,), (1,))) + self.expand = P.ExpandDims().shard(((1, device_num),)) + self.matmul_trans_b = P.BatchMatMul(transpose_b=True).shard(((1, device_num, 1), (1, 1))) + self.batch_matmul = P.BatchMatMul(transpose_b=True).shard(((1, device_num, 1), (1, 1))) + self.bias_add2 = P.Add().shard(((1, device_num, 1), (1,))) + self.bias_add = P.Add().shard(((1, device_num, 1), (1,))) + self.matmul = P.MatMul().shard(((1, device_num), (device_num, 1))) + self.trans = P.Transpose().shard(((1, 1, device_num, 1),)) + self.div = P.RealDiv().shard(((1, device_num, 1), (1, device_num, 1))) + self.add = P.Add().shard(((), (1, device_num, 1))) + self.mul = P.Mul().shard(((1, device_num, 1), (1, device_num, 1))) + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.act_dim = act_dim + self.batch_size = batch_size + self.slice_num = slice_num + self.idx = Tensor(0, mstype.int32) + # concat = [] + # for i in range(slice_num): + # concat.append((1, device_num, 1)) + self.memory_reduce = MemoryReduceCell(self.slice_num, device_num, strategy=[((device_num, 1, 1), (1,)),], dim=1) + #
self.memory_reduce.concat.shard(tuple(concat)) + # concat = [] + # for i in range(slice_num): + # concat.append((1, device_num, 1)) + # self.memory_reduce = MemoryReduceCell(self.slice_num, device_num, dim=1, strategy=[((1, device_num, 1), (1,)),]) + # self.memory_reduce.concat.shard(tuple(concat)) + self._init_parameter() + self.gather3 = P.Gather().shard(((1, 1, 1), ())) + self.gather2 = P.Gather().shard(((1, 1), ())) + + def construct(self, act, mask, mask_norm, index=None): + """Compute outer product mean.""" + + if self.batch_size: + layer_norm_input_gamma = self.gather2(self.layer_norm_input_gammas, index, 0) + layer_norm_input_beta = self.gather2(self.layer_norm_input_betas, index, 0) + left_projection_weight = self.gather3(self.left_projection_weights, index, 0) + left_projection_bias = self.gather2(self.left_projection_biases, index, 0) + right_projection_weight = self.gather3(self.right_projection_weights, index, 0) + right_projection_bias = self.gather2(self.right_projection_biases, index, 0) + linear_output_weight = self.gather3(self.linear_output_weights, index, 0) + linear_output_bias = self.gather2(self.o_biases, index, 0) + else: + layer_norm_input_gamma = self.layer_norm_input_gammas + layer_norm_input_beta = self.layer_norm_input_betas + left_projection_weight = self.left_projection_weights + left_projection_bias = self.left_projection_biases + right_projection_weight = self.right_projection_weights + right_projection_bias = self.right_projection_biases + linear_output_weight = self.linear_output_weights + linear_output_bias = self.o_biases + # mask = P.ExpandDims()(mask, -1) + mask = self.expand(mask, -1) + # act = self.layer_norm_input(act, layer_norm_input_gamma, layer_norm_input_beta) + act, _, _ = self.layer_norm_input(act, layer_norm_input_gamma, layer_norm_input_beta) + act_shape = P.Shape()(act) + # if len(act_shape) != 2: + # act = P.Reshape()(act, (-1, act_shape[-1])) + out_shape = act_shape[:-1] + (-1,) + # left_act = mask * P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act, left_projection_weight), left_projection_bias), out_shape) + left_act = self.mul(mask, self.bias_add(self.batch_matmul(act, left_projection_weight), left_projection_bias)) + # right_act = mask * P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act, right_projection_weight), right_projection_bias), out_shape) + right_act = self.mul(mask, self.bias_add(self.batch_matmul(act, right_projection_weight), right_projection_bias)) + a, d, e = right_act.shape + right_act = P.Reshape()(right_act, (a, -1)) + batched_inputs = (left_act,) + nonbatched_inputs = (right_act, linear_output_weight, linear_output_bias, d, e) + # act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num, 1) + act = self.memory_reduce(self._compute, batched_inputs, nonbatched_inputs) + epsilon = 1e-3 + # act = P.RealDiv()(act, epsilon + mask_norm) + act = self.div(act, self.add(epsilon, mask_norm)) + return act + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.layer_norm_input_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32)) + self.layer_norm_input_betas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32)) + self.left_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_outer_channel, self.act_dim)), mstype.float32)) + self.left_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_outer_channel)),
mstype.float32)) + self.right_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_outer_channel, self.act_dim)), mstype.float32)) + self.right_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_outer_channel)), mstype.float32)) + self.linear_output_weights = Parameter(Tensor(np.zeros( + (self.batch_size, self.num_output_channel, self.num_outer_channel * + self.num_outer_channel)), mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_output_channel)), mstype.float32)) + else: + self.layer_norm_input_gammas = Parameter(Tensor(np.ones((self.act_dim)), mstype.float32)) + self.layer_norm_input_betas = Parameter(Tensor(np.zeros((self.act_dim)), mstype.float32)) + self.left_projection_weights = Parameter( + initializer(lecun_init(self.act_dim), [self.num_outer_channel, self.act_dim])) + self.left_projection_biases = Parameter(Tensor(np.zeros((self.num_outer_channel)), mstype.float32)) + self.right_projection_weights = Parameter( + initializer(lecun_init(self.act_dim), [self.num_outer_channel, self.act_dim])) + self.right_projection_biases = Parameter(Tensor(np.zeros((self.num_outer_channel)), mstype.float32)) + self.linear_output_weights = Parameter( + Tensor(np.zeros((self.num_output_channel, self.num_outer_channel * self.num_outer_channel)), + mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros((self.num_output_channel)), mstype.float32)) + + def _compute(self, left_act, right_act, linear_output_weight, linear_output_bias, d, e): + '''compute outer product mean''' + + a, b, c = left_act.shape + left_act = P.Reshape()(P.Transpose()(left_act, (2, 1, 0)), (-1, a)) + act = P.Reshape()(self.trans(P.Reshape()(self.matmul(left_act, right_act), + (c, b, d, e)), (1, 2, 0, 3)), (b, d, c * e)) + # act_shape = P.Shape()(act) + # if len(act_shape) != 2: + # act = P.Reshape()(act, (-1, act_shape[-1])) + # act = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act, linear_output_weight), + # linear_output_bias), (d, b, -1)) + act = self.bias_add2(self.matmul_trans_b(act, linear_output_weight), + linear_output_bias) + # act = P.Transpose()(act, (1, 0, 2)) + return act \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24abb758d67f004381056a25968d5f7db8a9bb45 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Collective variables""" + +from .colvar import Colvar +from .base import Distance, Angle, Torsion +from .position import Atom, Position +from .atoms import AtomDistances, AtomAngles, AtomTorsions +from .bonded import BondedColvar, BondedDistances, BondedTorsions, BondedAngles +from .index import IndexColvar, IndexVectors, IndexDistances + +__all__ = ['Colvar', 'Distance', 'Angle', 'Torsion', 'Atom', 'Position', + 'AtomDistances', 'AtomAngles', 'AtomTorsions', 'BondedColvar', + 'BondedDistances', 'BondedTorsions', 'BondedAngles', 'IndexColvar', + 'IndexVectors', 'IndexDistances'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/atoms.py b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/atoms.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc980c5c87ccb6b5c027a1f79b5d6dedd99f3a8 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/atoms.py @@ -0,0 +1,226 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Collective variables for fixed atoms +""" + +import mindspore as ms +from mindspore import ops, nn +from mindspore import Tensor +from mindspore.ops import functional as F +from mindspore import numpy as msnp + +from ..function import functions as func +from ..function import get_ms_array +from .colvar import Colvar + + +class AtomDistances(Colvar): + r"""Distances between specific atoms + + Args: + index (Tensor): Index of the atom pairs. The last dimension must be 2. + + use_pbc (bool): Whether to use periodic boundary condition. Default: None + + length_unit (str): Length unit. Default: None + + """ + def __init__(self, + index: Tensor, + use_pbc: bool = None, + length_unit: str = None, + ): + + super().__init__( + dim_output=1, + periodic=False, + use_pbc=use_pbc, + length_unit=length_unit, + ) + + # (B,b,2) + self.index = get_ms_array(index, ms.int32) + if self.index.shape[-1] != 2: + raise ValueError('The last dimension of index in AtomDistances must be 2!') + self.dim_output = self.index.shape[-2] + self.identity = ops.Identity() + self.norm_last_dim = nn.Norm(axis=-1, keep_dims=False) + + def construct(self, coordinate: Tensor, pbc_box: Tensor = None): + r"""Compute distances. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Default: None + + Returns: + distances (Tensor): Tensor of shape (B, X). Data type is float.
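+
+        A minimal usage sketch (illustrative only, not part of the original source;
+        it assumes the package is importable as ``mindsponge1`` and uses two atom
+        pairs in a single-walker, 4-atom system):
+
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> from mindsponge1.colvar import AtomDistances
+        >>> index = Tensor(np.array([[[0, 1], [1, 2]]]), mstype.int32)    # (B=1, b=2, 2)
+        >>> coordinate = Tensor(np.random.rand(1, 4, 3), mstype.float32)  # (B=1, A=4, D=3)
+        >>> distances = AtomDistances(index)(coordinate)
+        >>> print(distances.shape)
+        (1, 2)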
+ + """ + + # (B,b,2) + index = self.identity(self.index) + # (B,b,2,D) + atoms = func.gather_vectors(coordinate, index) + + # (B,b,D) + vec = self.get_vector(atoms[..., 0, :], atoms[..., 1, :], pbc_box) + # (B,b) + return self.norm_last_dim(vec) + + +class AtomAngles(Colvar): + r"""Angles formed by specific atoms + + Args: + index (Tensor): Index of the atom triples. The last dimension must be 3. + use_pbc (bool): Whether to use periodic boundary condition. Default: None + + """ + def __init__(self, + index: Tensor, + use_pbc: bool = None, + ): + + super().__init__( + periodic=False, + use_pbc=use_pbc, + ) + + # (B,a,3) + self.index = get_ms_array(index, ms.int32) + if self.index.shape[-1] != 3: + raise ValueError('The last dimension of index in AtomAngles must be 3!') + self.dim_output = self.index.shape[-2] + self.split = ops.Split(-2, 3) + self.norm_last_dim = nn.Norm(axis=-1, keep_dims=False) + + def construct(self, coordinate: Tensor, pbc_box: Tensor = None): + r"""Compute angles. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Default: None + + Returns: + angles (Tensor): Tensor of shape (B, X). Data type is float. + + """ + + # (B,a,3) + index = self.identity(self.index) + # (B,a,3,D) + atoms = func.gather_vectors(coordinate, index) + + # (B,a,1,D) + atom0, atom1, atom2 = self.split(atoms) + + vec1 = self.get_vector(atom1, atom0, pbc_box).squeeze(-2) + vec2 = self.get_vector(atom1, atom2, pbc_box).squeeze(-2) + + # (B,a) <- (B,a,D) + dis1 = self.norm_last_dim(vec1) + dis2 = self.norm_last_dim(vec2) + + # (B,a) <- (B,a,D) + vec1vec2 = F.reduce_sum(vec1*vec2, -1) + # (B,a) = (B,a) * (B,a) + dis1dis2 = dis1 * dis2 + # (B,a)/(B,a) + costheta = vec1vec2 * msnp.reciprocal(dis1dis2) + + # (B,a) + return F.acos(costheta) + + +class AtomTorsions(Colvar): + r"""Torsion (dihedral) angle of specific atoms + + Args: + index (Tensor): Index of the atom quadruples. The last dimension must be 4. + use_pbc (bool): Whether to use periodic boundary condition. Default: None + + """ + def __init__(self, + index: Tensor, + use_pbc: bool = None, + ): + + super().__init__( + periodic=True, + use_pbc=use_pbc, + ) + + # (B,d,4) + self.index = get_ms_array(index, ms.int32) + if self.index.shape[-1] != 4: + raise ValueError('The last dimension of index in AtomTorsions must be 4!') + self.dim_output = self.index.shape[-2] + self.split = ops.Split(-2, 4) + self.keep_norm_last_dim = nn.Norm(axis=-1, keep_dims=True) + + def construct(self, coordinate: Tensor, pbc_box: Tensor = None): + r"""Compute torsions. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Default: None + + Returns: + torsion (Tensor): Tensor of shape (B, X). Data type is float.
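+
+        A minimal usage sketch (illustrative only, not part of the original source;
+        it assumes the package is importable as ``mindsponge1``):
+
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> from mindsponge1.colvar import AtomTorsions
+        >>> index = Tensor(np.array([[[0, 1, 2, 3]]]), mstype.int32)      # (B=1, d=1, 4)
+        >>> coordinate = Tensor(np.random.rand(1, 4, 3), mstype.float32)  # (B=1, A=4, D=3)
+        >>> torsion = AtomTorsions(index)(coordinate)
+        >>> print(torsion.shape)
+        (1, 1)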
+ + """ + + # (B,d,4) + index = self.identity(self.index) + # (B,d,4,D) + atoms = func.gather_vectors(coordinate, index) + + # (B,d,1,D) + atom_a, atom_b, atom_c, atom_d = self.split(atoms) + + # (B,d,D) <- (B,d,1,D) + vec_1 = self.get_vector(atom_b, atom_a, pbc_box).squeeze(-2) + vec_2 = self.get_vector(atom_c, atom_b, pbc_box).squeeze(-2) + vec_3 = self.get_vector(atom_d, atom_c, pbc_box).squeeze(-2) + + # (B,d,1) <- (B,d,D) + v2norm = self.keep_norm_last_dim(vec_2) + # (B,d,D) = (B,d,D) / (B,d,1) + norm_vec2 = vec_2 * msnp.reciprocal(v2norm) + + # (B,d,D) + vec_a = msnp.cross(norm_vec2, vec_1) + vec_b = msnp.cross(vec_3, norm_vec2) + cross_ab = msnp.cross(vec_a, vec_b) + + # (B,d) + sin_phi = F.reduce_sum(cross_ab*norm_vec2, -1) + cos_phi = F.reduce_sum(vec_a*vec_b, -1) + + # (B,d) + return F.atan2(-sin_phi, cos_phi) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/base.py b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7e499c63797ec39d5711a643d4d5d64472e2a515 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/base.py @@ -0,0 +1,177 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Collective variables by position +""" + +from mindspore import Tensor +from mindspore import nn + +from ..function import calc_angle_between_vectors, calc_torsion_for_vectors +from .colvar import Colvar +from .position import Position + + +class Distance(Colvar): + r"""Get distances by positions + + Args: + + position0 (Position): First position. + + position1 (Position): Second position. + + use_pbc (bool): Whether to calculate the CV at periodic boundary condition (PBC). + If "None" is given, it will be determined at runtime based on + whether the "pbc_box" is given or not. Default: None + + length_unit (str): Length unit for position coordinates. + If "None" is given, it will use the global units. Default: None + + """ + def __init__(self, + position0: Position, + position1: Position, + use_pbc: bool = None, + length_unit: str = None, + ): + + super().__init__( + dim_output=1, + periodic=False, + use_pbc=use_pbc, + length_unit=length_unit, + ) + + self.position0 = position0 + self.position1 = position1 + self.keep_norm_last_dim = nn.Norm(axis=-1, keep_dims=True) + + def construct(self, coordinate: Tensor, pbc_box: Tensor = None): + r"""Compute the distance between two positions. + + Args: + coordinate (ms.Tensor[B,N,D]) + + Returns: + distance (ms.Tensor[B,n,1]):
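+
+        A minimal usage sketch (illustrative only, not part of the original source;
+        it assumes ``Atom`` and ``Distance`` are importable from ``mindsponge1.colvar``):
+
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> from mindsponge1.colvar import Atom, Distance
+        >>> dis = Distance(Atom(0), Atom(1))
+        >>> coordinate = Tensor(np.random.rand(1, 4, 3), mstype.float32)  # (B=1, A=4, D=3)
+        >>> print(dis(coordinate).shape)
+        (1, 1)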
+ + """ + + pos0 = self.position0(coordinate) + pos1 = self.position1(coordinate) + + vec = self.get_vector(pos0, pos1, pbc_box) + return self.keep_norm_last_dim(vec) + + +class Angle(Colvar): + r"""Get angle by positions + + Args: + position_a (Position): First position. + position_b (Position): Second position (the vertex of the angle). + position_c (Position): Third position. + use_pbc (bool): Whether to calculate the CV at periodic boundary condition (PBC). Default: None + + """ + def __init__(self, + position_a: Position, + position_b: Position, + position_c: Position, + use_pbc: bool = None, + ): + + super().__init__( + dim_output=1, + periodic=False, + use_pbc=use_pbc, + ) + + self.position_a = position_a + self.position_b = position_b + self.position_c = position_c + + def construct(self, coordinate: Tensor, pbc_box: Tensor = None): + r"""Compute the angle formed by three positions. + + Args: + coordinate (ms.Tensor[B,N,D]) + + Returns: + angle (ms.Tensor[B,n,1]): + + """ + + pos_a = self.position_a(coordinate) + pos_b = self.position_b(coordinate) + pos_c = self.position_c(coordinate) + + vec_ba = self.get_vector(pos_b, pos_a, pbc_box) + vec_bc = self.get_vector(pos_b, pos_c, pbc_box) + + return calc_angle_between_vectors(vec_ba, vec_bc) + + +class Torsion(Colvar): + r"""Get torsion by positions + + Args: + position_a (Position): First position. + position_b (Position): Second position. + position_c (Position): Third position. + position_d (Position): Fourth position. + use_pbc (bool): Whether to calculate the CV at periodic boundary condition (PBC). Default: None + + """ + def __init__(self, + position_a: Position, + position_b: Position, + position_c: Position, + position_d: Position, + use_pbc: bool = None, + ): + + super().__init__( + dim_output=1, + periodic=True, + use_pbc=use_pbc, + ) + + self.position_a = position_a + self.position_b = position_b + self.position_c = position_c + self.position_d = position_d + + def construct(self, coordinate: Tensor, pbc_box: Tensor = None): + r"""Compute the torsion angle formed by four positions. + + Args: + coordinate (ms.Tensor[B,N,D]) + + Returns: + torsion (ms.Tensor[B,n,1]): + + """ + + pos_a = self.position_a(coordinate) + pos_b = self.position_b(coordinate) + pos_c = self.position_c(coordinate) + pos_d = self.position_d(coordinate) + + vec_ba = self.get_vector(pos_b, pos_a, pbc_box) + vec_cb = self.get_vector(pos_c, pos_b, pbc_box) + vec_dc = self.get_vector(pos_d, pos_c, pbc_box) + + return calc_torsion_for_vectors(vec_ba, vec_cb, vec_dc) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/bonded.py b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/bonded.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b941b24275805ca6433b0c736ad64096638750 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/bonded.py @@ -0,0 +1,173 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ +""" +Collective variables by bonds +""" + +import mindspore.numpy as msnp +from mindspore import Tensor +from mindspore import nn +from mindspore.ops import functional as F + +from ..function import functions as func +from .colvar import Colvar + + +class BondedColvar(Colvar): + r"""Get collective variables by bonds + + """ + + def __init__(self, + bond_index: int, + length_unit: str = None, + ): + + super().__init__( + dim_output=1, + periodic=False, + use_pbc=None, + length_unit=length_unit, + ) + + self.bond_index = bond_index + + def construct(self, bond_vectors: Tensor, bond_distances: Tensor): + #pylint: disable=arguments-differ + raise NotImplementedError + + +class BondedDistances(BondedColvar): + r"""Get distances by bonds + + """ + + def __init__(self, + bond_index: int = None, + length_unit: str = None, + ): + super().__init__( + bond_index=bond_index, + length_unit=length_unit, + ) + + def construct(self, bond_vectors: Tensor, bond_distances: Tensor): + r"""Compute distances from the bond information. + + Args: + bond_vectors (ms.Tensor[float]): bond vectors with shape (B,b,D) + bond_distances (ms.Tensor[float]): bond lengths with shape (B,b) + + Returns: + distances (ms.Tensor[float]): distances gathered by bond_index, with shape (B,M) + + """ + + distances = bond_distances + if self.bond_index is not None: + distances = func.gather_values(bond_distances, self.bond_index) + + return distances + + +class BondedAngles(BondedColvar): + r"""Get angles by bonds + + """ + + def __init__(self, bond_index: int): + super().__init__( + bond_index=bond_index, + ) + + def construct(self, bond_vectors: Tensor, bond_distances: Tensor): + r"""Compute angles formed by three atoms. + + Args: + bond_vectors (ms.Tensor[float]): bond vectors with shape (B,b,D) + bond_distances (ms.Tensor[float]): bond lengths with shape (B,b) + + Returns: + angles (ms.Tensor[float]): angles between bond pairs, with shape (B,a) + + """ + + # (B,a,2,D) <- gather (B,a,2) from (B,b,D) + vectors = func.gather_vectors(bond_vectors, self.bond_index) + # (B,a,2) <- gather (B,a,2) from (B,b) + distances = func.gather_values(bond_distances, self.bond_index) + + # (B,a) <- (B,a,D) + vec1vec2 = F.reduce_sum(vectors[:, :, 0, :]*vectors[:, :, 1, :], -1) + # (B,a) = (B,a) * (B,a) + dis1dis2 = distances[:, :, 0] * distances[:, :, 1] + # (B,a)/(B,a) + costheta = vec1vec2 * msnp.reciprocal(dis1dis2) + + # (B,a) + return F.acos(costheta) + + +class BondedTorsions(BondedColvar): + r"""Get torsion angles by bonds + + """ + + def __init__(self, bond_index: int): + super().__init__( + bond_index=bond_index, + ) + self.keep_norm_last_dim = nn.Norm(axis=-1, keep_dims=True) + + def construct(self, bond_vectors: Tensor, bond_distances: Tensor): + r"""Compute torsion angles formed by four atoms. + + Args: + bond_vectors (ms.Tensor[float]): bond vectors with shape (B,b,D) + bond_distances (ms.Tensor[float]): bond lengths with shape (B,b) + + Returns: + angles (ms.Tensor[float]): torsion angles with shape (B,M)
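+
+        A minimal usage sketch (illustrative only, not part of the original source;
+        one torsion defined over three consecutive bond vectors, assuming the
+        package is importable as ``mindsponge1``):
+
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> from mindsponge1.colvar import BondedTorsions
+        >>> torsions = BondedTorsions(Tensor(np.array([[[0, 1, 2]]]), mstype.int32))  # (B=1, M=1, 3)
+        >>> bond_vectors = Tensor(np.random.rand(1, 3, 3), mstype.float32)            # (B=1, b=3, D=3)
+        >>> bond_distances = Tensor(np.ones((1, 3)), mstype.float32)                  # (B=1, b=3)
+        >>> print(torsions(bond_vectors, bond_distances).shape)
+        (1, 1)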
+ + """ + + # (B,a,3,D) <- gather (B,a,3) from (B,b,D) + vectors = func.gather_vectors(bond_vectors, self.bond_index) + + vec_1 = vectors[:, :, 0, :] + vec_2 = vectors[:, :, 1, :] + vec_3 = vectors[:, :, 2, :] + + # (B,M,1) <- (B,M,D) + v2norm = self.keep_norm_last_dim(vec_2) + # (B,M,D) = (B,M,D) / (B,M,1) + norm_vec2 = vec_2 * msnp.reciprocal(v2norm) + + # (B,M,D) + vec_a = msnp.cross(norm_vec2, vec_1) + vec_b = msnp.cross(vec_3, norm_vec2) + cross_ab = msnp.cross(vec_a, vec_b) + + # (B,M) + sin_phi = F.reduce_sum(cross_ab*norm_vec2, -1) + cos_phi = F.reduce_sum(vec_a*vec_b, -1) + + # (B,M) + return F.atan2(-sin_phi, cos_phi) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/colvar.py b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/colvar.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7123f5fd4a56d123421ef0eabe52f0883bb1a9 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/colvar.py @@ -0,0 +1,113 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Collective variables +""" + +import mindspore as ms +from mindspore import ops +from mindspore.ops import functional as F +from mindspore.nn import Cell +from mindspore.common import Tensor + +from ..function import functions as func +from ..function.operations import GetVector +from ..function.units import Units, global_units + +class Colvar(Cell): + r"""Base class for collective variables. + + The function "construct" of Colvar must have the argument "coordinate". + + Args: + dim_output (int): The output dimension, i.e., the last dimension of output Tensor. + + periodic (bool): Whether the CV is periodic or not. Default: False + + use_pbc (bool): Whether to calculate the CV at periodic boundary condition (PBC). + If "None" is given, it will be determined at runtime based on + whether the "pbc_box" is given or not. Default: None + + length_unit (str): Length unit for position coordinates. + If "None" is given, it will use the global units.
Default: None + + """ + + def __init__(self, + dim_output: int = 1, + periodic: bool = False, + use_pbc: bool = None, + length_unit: str = None, + ): + + super().__init__() + + self.dim_output = dim_output + + self.get_vector = GetVector(use_pbc) + self.use_pbc = use_pbc + + if length_unit is not None: + self.use_global_units = False + self.units = Units(length_unit) + else: + self.use_global_units = True + self.units = global_units + + # the CV is periodic or not + if isinstance(periodic, bool): + periodic = Tensor([periodic]*self.dim_output, ms.bool_) + elif isinstance(periodic, (list, tuple)): + if len(periodic) != self.dim_output: + if len(periodic) == 1: + periodic = periodic * self.dim_output + else: + raise ValueError("The number of periodic values does not match dim_output") + periodic = Tensor(periodic, ms.bool_) + else: + raise TypeError("Unsupported type for periodic:" + + str(type(periodic))) + + self.periodic = F.reshape(periodic, (1, 1, self.dim_output)) + + self.any_periodic = self.periodic.any() + self.all_periodic = self.periodic.all() + + self.identity = ops.Identity() + + @property + def length_unit(self): + """length unit""" + return self.units.length_unit + + def vector_in_box(self, vector: Tensor, pbc_box: Tensor) -> Tensor: + """Wrap the difference of vectors into the range from -0.5 box to 0.5 box""" + return func.vector_in_box(vector, pbc_box) + + def set_pbc(self, use_pbc: bool): + """set periodic boundary condition""" + self.use_pbc = use_pbc + self.get_vector.set_pbc(use_pbc) + return self + + def construct(self, coordinate, pbc_box=None): + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/index.py b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/index.py new file mode 100644 index 0000000000000000000000000000000000000000..a888eec376d0cdbc765897f5df466a15ad3b31f2 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/index.py @@ -0,0 +1,203 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Collective variables that accept index +""" + +import mindspore as ms +from mindspore.ops import functional as F +from mindspore import nn +from mindspore.common import Tensor +from mindspore import numpy as msnp + +from ..function import functions as func +from .colvar import Colvar + + +class IndexColvar(Colvar): + r"""Collective variables based on index + + Args: + dim_output (int): The output dimension, i.e., the last dimension of output Tensor. + + periodic (bool): Whether the CV is periodic or not. Default: False + + use_pbc (bool): Whether to calculate the CV at periodic boundary condition (PBC). + If "None" is given, it will be determined at runtime based on + whether the "pbc_box" is given or not. Default: None + + length_unit (str): Length unit for position coordinates. + If "None" is given, it will use the global units. Default: None + + """ + + def __init__(self, + dim_output: int, + periodic: bool = False, + use_pbc: bool = None, + length_unit: str = None, + ): + + super().__init__( + dim_output=dim_output, + periodic=periodic, + use_pbc=use_pbc, + length_unit=length_unit, + ) + + def construct(self, coordinate: Tensor, index: Tensor, mask: Tensor = None, pbc_box: Tensor = None): + #pylint: disable=arguments-differ + raise NotImplementedError + + +class IndexDistances(IndexColvar): + r"""Calculate distance between atoms by neighbour index + + Args: + use_pbc (bool): Whether to use periodic boundary condition. Default: None + + length_unit (str): Length unit. Default: None + + large_dis (float): A large value added to distances that are equal to zero, to + prevent them from becoming zero after the Norm operation, + which could lead to auto-differentiation errors. Default: 100 + + keep_dims (bool): If this is "True", the last axis will be left in the result as + a dimension with size one. Default: False + + """ + + def __init__(self, + use_pbc: bool = None, + length_unit: str = None, + large_dis: float = 100, + keep_dims: bool = False, + ): + + super().__init__( + dim_output=1, + periodic=False, + use_pbc=use_pbc, + length_unit=length_unit, + ) + + self.norm_last_dim = nn.Norm(-1, keep_dims=keep_dims) + self.large_dis = Tensor(large_dis, ms.float32) + + def construct(self, coordinate: Tensor, index: Tensor, mask: Tensor = None, pbc_box: Tensor = None): + r"""Compute distances between atoms according to index. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Coordinate of system + index (Tensor): Tensor of shape (B, A, N). Data type is int. + Neighbour index + mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask of neighbour index + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Periodic boundary condition Box. + Default: None + + Returns: + distances (Tensor): Tensor of shape (B, A, N). Data type is float. + + Symbols: + + B: Batchsize, i.e. number of simulation walker. + A: Number of atoms. + N: Number of neighbour atoms. + D: Dimension of position coordinates.
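+
+        A minimal usage sketch (illustrative only, not part of the original source;
+        three atoms, each with its two neighbours as index, assuming the package
+        is importable as ``mindsponge1``):
+
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> from mindsponge1.colvar import IndexDistances
+        >>> index = Tensor(np.array([[[1, 2], [0, 2], [0, 1]]]), mstype.int32)  # (B=1, A=3, N=2)
+        >>> coordinate = Tensor(np.random.rand(1, 3, 3), mstype.float32)        # (B=1, A=3, D=3)
+        >>> print(IndexDistances()(coordinate, index).shape)
+        (1, 3, 2)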
+ If "None" is given, it will be determined at runtime based on + whether the "pbc_box" is given or not. Default: None + + length_unit (str): Length unit for position coordinates. + If "None" is given, it will use the global units. Default: None + + """ + + def __init__(self, + dim_output: int, + periodic: bool = False, + use_pbc: bool = None, + length_unit: str = None, + ): + + super().__init__( + dim_output=dim_output, + periodic=periodic, + use_pbc=use_pbc, + length_unit=length_unit, + ) + + def construct(self, coordinate: Tensor, index: Tensor, mask: Tensor = None, pbc_box: Tensor = None): + #pylint: disable=arguments-differ + raise NotImplementedError + + +class IndexDistances(IndexColvar): + r"""Calculate distance between atoms by neighbour index + + Args: + use_pbc (bool): Whether to use periodic boundary condition. Default: False + + length_unit (str): Length unit. Default: None + + large_dis (float): A large value that added to the distance equal to zero to + prevent them from becoming zero values after Norm operation, + which could lead to auto-differentiation errors. + + keep_dims (bool): If this is "True", the last axis will be left in the result as + dimensions with size one. + + """ + + def __init__(self, + use_pbc: bool = None, + length_unit: str = None, + large_dis: float = 100, + keep_dims: bool = False, + ): + + super().__init__( + dim_output=1, + periodic=False, + use_pbc=use_pbc, + length_unit=length_unit, + ) + + self.norm_last_dim = nn.Norm(-1, keep_dims=keep_dims) + self.large_dis = Tensor(large_dis, ms.float32) + + def construct(self, coordinate: Tensor, index: Tensor, mask: Tensor = None, pbc_box: Tensor = None): + r"""Compute distances between atoms according to index. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Coordinate of system + index (Tensor): Tensor of shape (B, A, N). Data type is int. + Neighbour index + mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask of neighbour index + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Periodic boundary condition Box. + Default: None + + Returns: + distances (Tensor): Tensor of shape (B, A, N). Data type is float. + + Symbols: + + B: Batchsize, i.e. number of simulation walker. + A: Number of atoms. + N: Number of neighbour atoms. + D: Dimension of position coordinates. + + """ + + # (B,A,1,D) <- (B,A,D) + atoms = F.expand_dims(coordinate, -2) + # (B,A,N,D) <- (B,A,D) + neighbours = func.gather_vectors(coordinate, index) + vectors = self.get_vector(atoms, neighbours, pbc_box) + + # Add a non-zero value to the vectors whose mask value is False + # to prevent them from becoming zero values after Norm operation, + # which could lead to auto-differentiation errors + if mask is not None: + # (B,A,N,D) = (B,A,N,D) + (B,A,N,1) + vectors += F.expand_dims(msnp.where(mask, 0, self.large_dis), -1) + + # (B,A,N) = (B,A,N,D) + return self.norm_last_dim(vectors) + + +class IndexVectors(IndexColvar): + r"""Get vectors by index + + Args: + use_pbc (bool): Whether to use periodic boundary condition. Default: False + + length_unit (str): Length unit. Default: None + + """ + + def __init__(self, + use_pbc: bool = None, + length_unit: str = None, + ): + + super().__init__( + dim_output=1, + periodic=False, + use_pbc=use_pbc, + length_unit=length_unit, + ) + + def construct(self, coordinate: Tensor, index: Tensor, mask: Tensor = None, pbc_box: Tensor = None): + r"""get vector by index. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. 
+ Coordinate of system + index (Tensor): Tensor of shape (B, A, N). Data type is int. + Neighbour index + mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask of neighbour index + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Periodic boundary condition Box. + Default: None + + Returns: + vector (Tensor): Tensor of shape (B, A, N, D). Data type is float. + + Symbols: + + B: Batchsize, i.e. number of simulation walker. + A: Number of atoms. + N: Number of neighbour atoms. + D: Dimension of position coordinates. + + """ + + # (B,A,1,D) <- (B,A,D) + atoms = F.expand_dims(coordinate, -2) + # (B,A,N,D) <- (B,A,D) + neighbours = func.gather_vectors(coordinate, index) + + return self.get_vector(atoms, neighbours, pbc_box) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/position.py b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/position.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3869addc47f2a0c4690d139009c8a6a64479c4 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/colvar/position.py @@ -0,0 +1,68 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Position +""" + +import mindspore as ms +from mindspore.common import Tensor + +from .colvar import Colvar + + +class Position(Colvar): + r"""Position coordinate + + Args: + dim_output (int): Output dimension. Default: 3 + use_pbc (bool): Whether to use periodic boundary condition. Default: None + + """ + def __init__(self, + dim_output: int = 3, + use_pbc: bool = None + ): + + super().__init__( + dim_output=dim_output, + periodic=False, + use_pbc=use_pbc + ) + + def construct(self, coordinate, pbc_box=None): + raise NotImplementedError + + +class Atom(Position): + r"""Atom position + + Args: + index (int): index of atoms + + """ + def __init__(self, index: int): + super().__init__() + self.index = Tensor(index, ms.int32) + + def construct(self, coordinate, pbc_box=None): + return coordinate[..., self.index, :] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/common/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bbbed7321d31bf477ac14b6a9125df577303de0d --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/common/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" + +from .geometry import vecs_scale, rots_scale, vecs_sub, vecs_robust_norm, vecs_robust_normalize +from .geometry import vecs_cross_vecs, rots_from_two_vecs, rigids_from_3_points, invert_rots +from .geometry import vecs_dot_vecs, rots_mul_vecs, invert_rigids, rigids_mul_vecs, rigids_mul_rots +from .geometry import rigids_mul_rigids, rots_mul_rots, vecs_from_tensor, vecs_to_tensor +from .geometry import make_transform_from_reference, rots_from_tensor, rots_to_tensor +from .geometry import quat_affine, quat_to_rot, initial_affine, vecs_expand_dims +from .geometry import rots_expand_dims, invert_point, quat_multiply_by_vec, quaternion_to_tensor +from .geometry import quaternion_from_tensor, apply_to_point, pre_compose +from .utils import get_pdb_info, make_atom14_positions, get_fasta_info, get_aligned_seq, find_optimal_renaming +__all__ = ["get_pdb_info", "make_atom14_positions", "get_fasta_info", "get_aligned_seq", + "vecs_scale", "rots_scale", "vecs_sub", "vecs_robust_norm", "vecs_robust_normalize", + "vecs_cross_vecs", "rots_from_two_vecs", "rigids_from_3_points", "invert_rots", + "vecs_dot_vecs", "rots_mul_vecs", "invert_rigids", "rigids_mul_vecs", "rigids_mul_rots", + "rigids_mul_rigids", "rots_mul_rots", "vecs_from_tensor", "vecs_to_tensor", + "make_transform_from_reference", "rots_from_tensor", "rots_to_tensor", + "quat_affine", "quat_to_rot", "initial_affine", "vecs_expand_dims", + "rots_expand_dims", "invert_point", "quat_multiply_by_vec", "quaternion_to_tensor", + "quaternion_from_tensor", "apply_to_point", "pre_compose", "find_optimal_renaming"] + +__all__.sort() diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/common/config_load.py b/MindSPONGE/applications/research/Grasp/mindsponge1/common/config_load.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9132cad4deb4f0637c6f3984fe741cdff3df71 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/common/config_load.py @@ -0,0 +1,43 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""config load""" +from pprint import pformat +import yaml + +class Config: + """ + Configuration namespace. Convert dictionary to members. 
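+
+    A minimal usage sketch (illustrative only; nested dictionaries become
+    nested ``Config`` members):
+
+    >>> cfg = Config({'model': {'hidden_dim': 64}, 'seed': 1})
+    >>> cfg.model.hidden_dim
+    64
+    >>> cfg.seed
+    1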
+ """ + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + +def load_config(path): + """ + Convert yaml file to Obj. + """ + f = open(path, 'r') + config = yaml.load(f, Loader=yaml.FullLoader) + config = Config(config) + return config diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/common/geometry.py b/MindSPONGE/applications/research/Grasp/mindsponge1/common/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7da2718a5bfaed0f48fc15b70b9145d2b48644 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/common/geometry.py @@ -0,0 +1,1467 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Geometry""" +import numpy as np +import mindspore.numpy as mnp +from mindspore import Tensor +from mindspore.ops import operations as P + +QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32) +QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, -1]] + +QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 1], + [0, 0, -1, 0]] + +QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], + [0, 0, 0, -1], + [1, 0, 0, 0], + [0, 1, 0, 0]] + +QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], + [0, 0, 1, 0], + [0, -1, 0, 0], + [1, 0, 0, 0]] + +QUAT_MULTIPLY_BY_VEC = Tensor(QUAT_MULTIPLY[:, 1:, :]) + +QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) + +QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr +QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii +QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj +QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk + +QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij +QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik +QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk + +QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir +QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr +QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr + +QUAT_TO_ROT = Tensor(QUAT_TO_ROT) + + +def vecs_scale(v, scale): + r""" + Scale the vector. + + .. math:: + \begin{split} + &v=(x1,x2,x3) \\ + &scaled\_{vecs} = (scale*x1,scale*x2,scale*x3) \\ + \end{split} + + Args: + v(Tuple): Vector will be scaled, :math:`(x,y,z)`. x, y, z are scalars or Tensor with same shape. + scale(float): Value of scale. + + Returns: + Tuple with length of 3, vector after scaled with the same shape as input v. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindspore import dtype as mstype + >>> from mindsponge.common.geometry import vecs_scale + >>> x= Tensor(np.ones(256), mstype.float32) + >>> y= Tensor(np.ones(256), mstype.float32) + >>> z= Tensor(np.ones(256), mstype.float32) + >>> scale=10 + >>> result=vecs_scale((x,y,z),scale) + >>> print(len(result)) + >>> print(result[0].shape) + >>> print(result[1].shape) + >>> print(result[2].shape) + 3 + (256,) + (256,) + (256,) + """ + scaled_vecs = (v[0] * scale, v[1] * scale, v[2] * scale) + return scaled_vecs + + +def rots_scale(rot, scale): + r""" + Scaling of rotation matrixs. + + .. math:: + \begin{split} + &rot=(xx,xy,xz,yx,yy,yz,zx,zy,zz) \\ + &scaled\_{rots} = (scale*xx,scale*xy,scale*xz,scale*yx,scale*yy,scale*yz,scale*zx,scale*zy,scale*zz) + \end{split} + + Args: + rot(Tuple): Rots, length is 9, :math:`(xx,xy,xz,yx,yy,yz,zx,zy,zz)` . Data type is scalar or + Tensor with the same shape. + scale(float): Value of scale. + + Returns: + Tuple, scaled rotation matrixs. Length is 9, shape is the same as the input rots' shape. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindspore import dtype as mstype + >>> from mindsponge.common.geometry import rots_scale + >>> x = Tensor(np.ones(256), mstype.float32) + >>> result = rots_scale((x, x, x, x, x, x, x, x, x),10) + >>> print(len(result)) + >>> print(result[0].shape) + >>> print(result[1].shape) + >>> print(result[2].shape) + >>> print(result[3].shape) + >>> print(result[4].shape) + >>> print(result[5].shape) + >>> print(result[6].shape) + >>> print(result[7].shape) + >>> print(result[8].shape) + 3 + (256,) + (256,) + (256,) + (256,) + (256,) + (256,) + (256,) + (256,) + (256,) + """ + scaled_rots = (rot[0] * scale, rot[1] * scale, rot[2] * scale, + rot[3] * scale, rot[4] * scale, rot[5] * scale, + rot[6] * scale, rot[7] * scale, rot[8] * scale) + return scaled_rots + + +def vecs_sub(v1, v2): + r""" + Subtract two vectors. + + .. math:: + \begin{split} + &v1=(x1,x2,x3) \\ + &v2=(x1',x2',x3') \\ + &result=(x1-x1',x2-x2',x3-x3') \\ + \end{split} + + Args: + v1(Tuple): input vector 1 :math:`(x, y, z)`, data type is scalar or Tensor with same shape. + v2(Tuple): input vector 2 :math:`(x, y, z)`, data type is scalar or Tensor with same shape. + + Returns: + Tuple. Length is 3, :math:`(x', y', z')` , data type is scalar or Tensor with same shape as v1. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindspore import dtype as mstype + >>> from mindsponge.common.geometry import vecs_sub + >>> x= Tensor(np.ones(256), mstype.float32) + >>> y= Tensor(np.ones(256), mstype.float32) + >>> z= Tensor(np.ones(256), mstype.float32) + >>> result=vecs_sub((x,y,z),(x,y,z)) + >>> print(len(result)) + >>> print(result[0].shape) + >>> print(result[1].shape) + >>> print(result[2].shape) + 3 + (256,) + (256,) + (256,) + """ + return (v1[0] - v2[0], v1[1] - v2[1], v1[2] - v2[2]) + + +def vecs_robust_norm(v, epsilon=1e-8): + r""" + Calculate the l2-norm of a vector. + + .. math:: + \begin{split} + &v=(x1,x2,x3) \\ + &l2\_norm=\sqrt{x1*x1+x2*x2+x3*x3+epsilon} \\ + \end{split} + + Args: + v(Tuple): Input vector :math:`(x,y,z)` . Data type is scalar or Tensor with same shape. + epsilon(float): A very small number to prevent the result from being 0. 
Default: 1e-8. + + Returns: + Tensor, 2-Norm calculated by vector v. Shape is the same as v. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindspore import dtype as mstype + >>> from mindsponge.common.geometry import vecs_robust_norm + >>> x= Tensor(np.ones(256), mstype.float32) + >>> y= Tensor(np.ones(256), mstype.float32) + >>> z= Tensor(np.ones(256), mstype.float32) + >>> result=vecs_robust_norm((x,y,z)) + >>> print(result.shape) + (256) + """ + v_l2_norm = v[0] * v[0] + v[1] * v[1] + v[2] * v[2] + epsilon + v_norm = v_l2_norm ** 0.5 + return v_norm + + +def vecs_robust_normalize(v, epsilon=1e-8): + r""" + Use l2-norm normalization vectors + + .. math:: + \begin{split} + &v=(x1,x2,x3) \\ + &l2\_norm=\sqrt{x1*x1+x2*x2+x3*x3+epsilon} \\ + &result=(x1/l2\_norm, x2/l2\_norm, x3/l2\_norm) \\ + \end{split} + + Args: + v(Tuple): Input vector :math:`(x,y,z)` . Data type is scalar or Tensor with same shape. + epsilon(float): Minimal value, prevent the result from being 0. Default: 1e-8. + + Returns: + Tuple with length of 3, normalized 2-Norm calculated by vector v. Shape is the same as v. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindspore import dtype as mstype + >>> from mindsponge.common.geometry import vecs_robust_normalize + >>> x= Tensor(np.ones(256), mstype.float32) + >>> y= Tensor(np.ones(256), mstype.float32) + >>> z= Tensor(np.ones(256), mstype.float32) + >>> result=vecs_robust_normalize((x,y,z)) + >>> print(len(result)) + >>> print(result[0].shape) + >>> print(result[1].shape) + >>> print(result[2].shape) + 3 + (256,) + (256,) + (256,) + """ + norms = vecs_robust_norm(v, epsilon) + return (v[0] / norms, v[1] / norms, v[2] / norms) + + +def vecs_dot_vecs(v1, v2): + r""" + Dot product of vectors :math:`v_1 = (x_1, x_2, x_3)` and :math:`v_2 = (y_1, y_2, y_3)`. + + .. math:: + res = x_1 * y_1 + x_2 * y_2 + x_3 * y_3 + + Args: + v1 (tuple): vectors :math:`\vec v_1` , length is 3. + Data type is constant or Tensor with same shape. + v2 (tuple): vectors :math:`\vec v_2` , length is 3. + Data type is constant or Tensor with same shape. + + Returns: + float or Tensor with the same shape as the Tensor in input, dot product result of two vectors . + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> v1 = (1, 2, 3) + >>> v2 = (3, 4, 5) + >>> ans = mindsponge.common.vecs_dot_vecs(v1, v2) + >>> print(ans) + 26 + """ + res = v1[0] * v2[0] + v1[1] * v2[1] + v1[2] * v2[2] + return res + + +def vecs_cross_vecs(v1, v2): + r""" + Cross product of vectors :math:`v_1 = (x_1, x_2, x_3)` and :math:`v_2 = (y_1, y_2, y_3)`. + + .. math:: + cross_{res} = (x_2 * y_3 - x_3 * y_2, x_3 * y_1 - x_1 * y_3, x_1 * y_2 - x_2 * y_1) + + Args: + v1 (tuple): vectors :math:`\vec v_1` , length is 3. + Data type is constant or Tensor with same shape. + v2 (tuple): vectors :math:`\vec v_2` , length is 3. + Data type is constant or Tensor with same shape. + + Returns: + tuple, cross product result of two vectors, length is 3. + Data type is constant or Tensor with same shape. 
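+
+    Note: together with `vecs_dot_vecs` above, this is the building block that
+    `rots_from_two_vecs` below uses to complete a right-handed local frame.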
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> v1 = (1, 2, 3) + >>> v2 = (3, 4, 5) + >>> ans = mindsponge.common.vecs_cross_vecs(v1, v2) + >>> print(ans) + (2, -4, 2) + """ + cross_res = (v1[1] * v2[2] - v1[2] * v2[1], + v1[2] * v2[0] - v1[0] * v2[2], + v1[0] * v2[1] - v1[1] * v2[0]) + return cross_res + + +def rots_from_two_vecs(e0_unnormalized, e1_unnormalized): + r""" + Put in two vectors :math:`\vec a = (a_x, a_y, a_z)` and :math:`\vec b = (b_x, b_y, b_z)`. + Calculate the rotation matrix between local coordinate system, in which the x-y plane + consists of two input vectors and global coordinate system. + + Calculate the unit vector :math:`\vec e_0 = \frac{\vec a}{|\vec a|}` + as the unit vector of x axis. + + Then calculate the projected length of :math:`\vec b` on a axis. + :math:`c = |\vec b| \cos\theta = \vec b \cdot \frac{\vec a}{|\vec a|}` . + + So the projected vector of :math:`b` on a axis is :math:`c\vec e_0`. + The vector perpendicular to e0 is :math:`\vec e_1' = \vec b - c\vec e_0` . + + The unit vector of :math:`\vec e_1'` is :math:`\vec e_1 = \frac{\vec e_1'}{|\vec e_1'|}`, + which is the y axis of the local coordinate system. + + Finally get the unit vector of z axis :math:`\vec e_2` by calculating cross product of + :math:`\vec e_1` and :math:`\vec e_0`. + + Args: + e0_unnormalized (tuple): vectors :math:`\vec a` as x-axis of x-y plane, + length is 3. Data type is constant or Tensor with same shape. + e1_unnormalized (tuple): vectors :math:`\vec b` forming x-y plane, + length is 3. Data type is constant or Tensor with same shape. + + Returns: + tuple, rotation matrix :math:`(e_0_x, e_1_x, e_2_x, e_0_y, e_1_y, e_2_y, e_0_z, e_1_z, e_2_z)` . + Data type is constant or Tensor with same shape. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> v1 = (1, 2, 3) + >>> v2 = (3, 4, 5) + >>> ans = mindsponge.common.rots_from_two_vecs(v1, v2) + >>> print(ans) + (0.4242640686695021, -0.808290367995452, 0.40824828617045156, 0.5656854248926695, + -0.1154700520346678, -0.8164965723409039, 0.7071067811158369, 0.5773502639261153, + 0.4082482861704521) + """ + + # Normalize the unit vector for the x-axis, e0. + e0 = vecs_robust_normalize(e0_unnormalized) + + # make e1 perpendicular to e0. + c = vecs_dot_vecs(e1_unnormalized, e0) + e1 = vecs_sub(e1_unnormalized, vecs_scale(e0, c)) + e1 = vecs_robust_normalize(e1) + + # Compute e2 as cross product of e0 and e1. + e2 = vecs_cross_vecs(e0, e1) + rots = (e0[0], e1[0], e2[0], + e0[1], e1[1], e2[1], + e0[2], e1[2], e2[2]) + return rots + + +def rigids_from_3_points(point_on_neg_x_axis, origin, point_on_xy_plane): + r""" + Gram-Schmidt process. Create rigids representation of 3 points local coordination system, + point on negative x axis A, origin point O and point on x-y plane P. + + First calculate the coordinations of vector :math:`\vec AO` and :math:`\vec OP`. Then + use `rots_from_two_vecs` get the rotation matrix. + + Distance between origin point O and the origin point of global coordinate system is + the translations of rigid. + + Finally return the rotations and translations of rigid. + + Reference: + `Jumper et al. (2021) Suppl. Alg. 21 'Gram-Schmidt process' + `_. + + .. 
math:: + \begin{split} + &\vec v_1 = \vec x_3 - \vec x_2 \\ + &\vec v_2 = \vec x_1 - \vec x_2 \\ + &\vec e_1 = \vec v_1 / ||\vec v_1|| \\ + &\vec u_2 = \vec v_2 - \vec e_1(\vec e_1^T\vec v_2) \\ + &\vec e_2 = \vec u_2 / ||\vec u_2|| \\ + &\vec e_3 = \vec e_1 \times \vec e_2 \\ + &rotation = (\vec e_1, \vec e_2, \vec e_3) \\ + &translation = (\vec x_2) \\ + \end{split} + + Args: + point_on_neg_x_axis (tuple): point on negative x axis A, length is 3. + Data type is constant or Tensor with same shape. + origin (tuple): origin point O, length is 3. + Data type is constant or Tensor with same shape. + point_on_xy_plane (tuple): point on x-y plane P, length is 3. + Data type is constant or Tensor with same shape. + + Returns: + tuple(rots, trans), rigid, length is 2. Include rots :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)` + and trans :math:`(x, y, z)` . Data type is constant or Tensor with same shape. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> A = (1, 2, 3) + >>> O = (4, 6, 8) + >>> P = (5, 8, 11) + >>> ans = mindsponge.common.rigids_from_3_points(A, O, P) + >>> print(ans) + ((0.4242640686695021, -0.808290367995452, 0.40824828617045156, 0.5656854248926695, + -0.1154700520346678, -0.8164965723409039, 0.7071067811158369, 0.5773502639261153, + 0.4082482861704521), (4,6,8)) + """ + m = rots_from_two_vecs( + e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis), + e1_unnormalized=vecs_sub(point_on_xy_plane, origin)) + rigid = (m, origin) + return rigid + + +def invert_rots(m): + r""" + Computes inverse of rotations :math:`m`. + + rotations :math:`m = (xx, xy, xz, yx, yy, yz, zx, zy, zz)` and + inverse of :math:`m` is :math:`m^{T} = (xx, yx, zx, xy, yy, zy, xz, yz, zz)` . + + Args: + m (tuple): rotations :math:`m` , length is 9. + Data type is constant or Tensor with same shape. + + Returns: + tuple, inverse of rotations :math:`m` , length is 9. Data type is constant or Tensor with same shape. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> m = (1, 2, 3, 4, 5, 6, 7, 8, 9) + >>> inv_m = mindsponge.common.invert_rots(m) + >>> print(inv_m) + (1, 4, 7, 2, 5, 8, 3, 6, 9) + """ + invert = (m[0], m[3], m[6], + m[1], m[4], m[7], + m[2], m[5], m[8]) + return invert + + +def rots_mul_vecs(m, v): + r""" + Apply rotations :math:`\vec m = (m_0, m_1, m_2, m_3, m_4, m_5, m_6, m_7, m_8)` + to vectors :math:`\vec v = (v_0, v_1, v_2)`. + + .. math:: + out = m \cdot v^T = (m_0 \times v_0 + m_1 \times v_1 + m_2 \times v_2, + m_3 \times v_0 + m_4 \times v_1 + m_5 \times v_2, + m_6 \times v_0 + m_7 \times v_1 + m_8 \times v_2) + + Args: + m (tuple): rotations :math:`\vec m` , length is 9. + Data type is constant or Tensor with same shape. + v (tuple): vectors :math:`\vec v` , length is 3. + Data type is constant or Tensor with same shape. + + Returns: + tuple, vectors after rotations. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> m = (1, 2, 3, 4, 5, 6, 7, 8, 9) + >>> v = (1, 2, 3) + >>> v1 = mindsponge.common.rots_mul_vecs(m, v) + >>> print(v1) + (14, 32, 50) + """ + out = (m[0] * v[0] + m[1] * v[1] + m[2] * v[2], + m[3] * v[0] + m[4] * v[1] + m[5] * v[2], + m[6] * v[0] + m[7] * v[1] + m[8] * v[2]) + return out + + +def invert_rigids(rigids): + r""" + Computes group inverse of rigid transformations. Change rigid from + local coordinate system to global coordinate system. + + Use `invert_rots` to calculate the invert rotations of rigid. 
Then use + `rots_mul_vecs` to rotate the translations of rigid. The opposite of the + result is the translations of invert rigid. + + .. math:: + inv\_rots = r_r^T = (r_0, r_3, r_6, r_1, r_4, r_7, r_2, r_5, r_8) + + inv\_trans = -r_r^T \cdot r_t^T = (- (r_0 \times t_0 + r_3 \times t_0 + r_6 \times t_0), + - (r_1 \times t_1 + r_4 \times t_1 + r_7 \times t_1), + - (r_2 \times t_2 + r_5 \times t_2 + r_8 \times t_2)) + + Args: + rigids (tuple): rigids, including the rots and trans changing rigids + from global coordinate system to local coordinate system. + + Returns: + tuple(rots, trans), group inverse of rigid transformations, length is 2. Include rots + :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)` and trans :math:`(x, y, z)` . + Data type is constant or Tensor with same shape. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> a = ((1, 2, 3, 4, 5, 6, 7, 8, 9), (3, 4, 5)) + >>> inv_a = mindsponge.common.invert_rigids(a) + >>> print(inv_a) + ((1, 4, 7, 2, 5, 8, 3, 6, 9), (-54.0, -66.0, -78.0)) + """ + rot, trans = rigids + inv_rots = invert_rots(rot) + t = rots_mul_vecs(inv_rots, trans) + inv_trans = (-1.0 * t[0], -1.0 * t[1], -1.0 * t[2]) + inv_rigids = (inv_rots, inv_trans) + return inv_rigids + + +def vecs_add(v1, v2): + """Add two vectors 'v1' and 'v2'.""" + return (v1[0] + v2[0], v1[1] + v2[1], v1[2] + v2[2]) + + +def rigids_mul_vecs(rigids, v): + r""" + Transform vector :math:`v` to rigid' local coordinate system. + + Multiply vector :math:`v` and the rotations of rigid together + and add the translations of rigid. The result is the output vector. + + .. math:: + v = r_rv+r_t + + Args: + rigids (tuple): rigid. + v (tuple): vector :math:`\vec v` , length is 3. Data type is constant or Tensor with same shape. + + Returns: + tuple, changed vector, length is 3. Data type is constant or Tensor with same shape. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> a = ((1, 2, 3, 4, 5, 6, 7, 8, 9), (3, 4, 5)) + >>> b = (1, 2, 3) + >>> b1 = mindsponge.common.rigids_mul_vecs(a,b) + >>> print(b1) + (17, 36, 55) + """ + return vecs_add(rots_mul_vecs(rigids[0], v), rigids[1]) + + +def rigids_mul_rots(x, y): + r""" + Numpy version of getting results rigid :math:`x` multiply rotations :math:`\vec y` . + + Multiply rotations of rigid :math:`x` with rotations :math:`y`, + the result is rigids new rotations. Translations of rigid will not changed. + + .. math:: + (r, t) = (x_ry, x_t) + + Args: + x (tuple): rigid :math:`x` . Length is 2. Include rots :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)` + and trans :math:`(x, y, z)` . Data type is constant or Tensor with same shape. + y (tuple): rotations :math:`\vec y` , length is 9. Data type is constant or Tensor with same shape. + + Returns: + tuple(rots, trans), length is 2, rigid whose rotations are changed. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> a = ((1, 2, 3, 4, 5, 6, 7, 8, 9), (3, 4, 5)) + >>> b = (2, 3, 4, 1, 5, 6, 3, 8, 7) + >>> b1 = mindsponge.common.rigids_mul_rots(a,b) + >>> print(b1) + ((13, 37, 37, 31, 85, 88, 49, 133, 139), (3, 4, 5)) + """ + rigids = (rots_mul_rots(x[0], y), x[1]) + return rigids + + +def rigids_mul_rigids(a, b): + r""" + Change rigid :math:`b` from its local coordinate system to rigid :math:`a` + local coordinate system, using rigid :math:`a` rotations and translations. + + Use the rotations calculated by multiplying rotations of rigid :math:`b` + and rigid :math:`a` as new rotations of rigid :math:`b` . 
+
+    Multiply the translations of rigid :math:`b` with rotations of rigid :math:`a` ,
+    then add translations of rigid :math:`a` . The result is the new translation
+    of rigid :math:`b`.
+
+    .. math::
+        \begin{split}
+        &r = a_rb_r \\
+        &t = a_rb_t + a_t \\
+        \end{split}
+
+    Args:
+        a (tuple): rigid :math:`a` . Length is 2. Include rots :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`
+            and trans :math:`(x, y, z)` . Data type is constant or Tensor with same shape.
+        b (tuple): rigid :math:`b` . Length is 2. Include rots :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`
+            and trans :math:`(x, y, z)` . Data type is constant or Tensor with same shape.
+
+    Returns:
+        tuple(rots, trans), rigid :math:`b` changed. Length is 2.
+        Include rots :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`
+        and trans :math:`(x, y, z)` . Data type is constant or Tensor with same shape.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import mindsponge
+        >>> a = ((1, 2, 3, 4, 5, 6, 7, 8, 9), (3, 4, 5))
+        >>> b = ((2, 3, 4, 1, 5, 6, 3, 8, 7), (1, 2, 3))
+        >>> b1 = mindsponge.common.rigids_mul_rigids(a,b)
+        >>> print(b1)
+        ((13, 37, 37, 31, 85, 88, 49, 133, 139), (17, 36, 55))
+    """
+    rot = rots_mul_rots(a[0], b[0])
+    trans = vecs_add(a[1], rots_mul_vecs(a[0], b[1]))
+    return (rot, trans)
+
+
+def rots_mul_rots(x, y):
+    r"""
+    Get the result of multiplying rotation matrix x by rotation matrix y.
+
+    .. math::
+        \begin{split}
+        &xx = xx1*xx2 + xy1*yx2 + xz1*zx2 \\
+        &xy = xx1*xy2 + xy1*yy2 + xz1*zy2 \\
+        &xz = xx1*xz2 + xy1*yz2 + xz1*zz2 \\
+        &yx = yx1*xx2 + yy1*yx2 + yz1*zx2 \\
+        &yy = yx1*xy2 + yy1*yy2 + yz1*zy2 \\
+        &yz = yx1*xz2 + yy1*yz2 + yz1*zz2 \\
+        &zx = zx1*xx2 + zy1*yx2 + zz1*zx2 \\
+        &zy = zx1*xy2 + zy1*yy2 + zz1*zy2 \\
+        &zz = zx1*xz2 + zy1*yz2 + zz1*zz2 \\
+        \end{split}
+
+    Args:
+        x(tuple): rots x, :math:`(xx1, xy1, xz1, yx1, yy1, yz1, zx1, zy1, zz1)`.
+        y(tuple): rots y, :math:`(xx2, xy2, xz2, yx2, yy2, yz2, zx2, zy2, zz2)`.
+
+    Returns:
+        tuple, the result of rots x multiplying rots y, in the form :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> from mindsponge.common.geometry import rots_mul_rots
+        >>> rots_0 = (1, 1, 1, 1, 1, 1, 1, 1, 1)
+        >>> rots_1 = (1, 1, 1, 1, 1, 1, 1, 1, 1)
+        >>> result = rots_mul_rots(rots_0, rots_1)
+        >>> print(result)
+        (3, 3, 3, 3, 3, 3, 3, 3, 3)
+    """
+    vecs0 = rots_mul_vecs(x, (y[0], y[3], y[6]))
+    vecs1 = rots_mul_vecs(x, (y[1], y[4], y[7]))
+    vecs2 = rots_mul_vecs(x, (y[2], y[5], y[8]))
+    rots = (vecs0[0], vecs1[0], vecs2[0], vecs0[1], vecs1[1], vecs2[1], vecs0[2], vecs1[2], vecs2[2])
+    return rots
+
+
+def vecs_from_tensor(inputs):
+    """
+    Get vectors from the last axis of input tensor.
+
+    Args:
+        inputs(Tensor): Atom position information. Shape is :math:`(..., 3)`.
+
+    Returns:
+        tuple :math:`(x, y, z)` , including the coordinate information of x, y and z.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor
+        >>> from mindsponge.common.geometry import vecs_from_tensor
+        >>> input_0 = Tensor(np.ones((4, 256, 3)), ms.float32)
+        >>> output = vecs_from_tensor(input_0)
+        >>> print(len(output), output[0].shape)
+        3 (4, 256)
+    """
+    num_components = inputs.shape[-1]
+    assert num_components == 3
+    return (inputs[..., 0], inputs[..., 1], inputs[..., 2])
+
+
+def vecs_to_tensor(v):
+    """
+    Converts 'v' to tensor with last dim shape 3, inverse of 'vecs_from_tensor'.
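+
+    Note: together with `vecs_from_tensor` above, this gives a lossless round trip
+    between the tuple-of-tensors layout and the stacked-tensor layout.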
+ + Args: + v(tuple): Input tuple v :math:`(x, y, z)`, including the coordinate information of x, y and z. + + Returns: + tensor, concat the tensor in last dims, shape :math:`(..., 3)` . + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindsponge.common.geometry import vecs_to_tensor + >>> input_0 = Tensor(np.ones((4, 256)), ms.float32) + >>> input_1 = Tensor(np.ones((4, 256)), ms.float32) + >>> input_2 = Tensor(np.ones((4, 256)), ms.float32) + >>> inputs = (input_0, input_1, input_2) + >>> output = vecs_to_tensor(inputs) + >>> print(output.shape) + (4, 256, 3) + """ + return mnp.stack([v[0], v[1], v[2]], axis=-1) + + +def make_transform_from_reference(point_a, point_b, point_c): + r""" + Using GramSchmidt process to construct rotation and translation from given points. + + Calculate the rotation matrix and translation meets + + a) point_b is the original point. + + b) point_c is on the x_axis. + + c) the plane a-b-c is on the x-y plane. + + .. math:: + \begin{split} + &\vec v_1 = \vec x_3 - \vec x_2 \\ + &\vec v_2 = \vec x_1 - \vec x_2 \\ + &\vec e_1 = \vec v_1 / ||\vec v_1|| \\ + &\vec u_2 = \vec v_2 - \vec e_1(\vec e_1^T\vec v_2) \\ + &\vec e_2 = \vec u_2 / ||\vec u_2|| \\ + &\vec e_3 = \vec e_1 \times \vec e_2 \\ + &rotation = (\vec e_1, \vec e_2, \vec e_3) \\ + &translation = (\vec x_2) \\ + \end{split} + + Args: + point_a(float, tensor) -> (tensor): Spatial location information of atom 'N', + shape is :math:`[..., N_{res}, 3]` . + point_b(float, tensor) -> (tensor): Spatial location information of atom 'CA', + shape is :math:`[..., N_{res}, 3]` . + point_c(float, tensor) -> (tensor): Spatial location information of atom 'C', + shape is :math:`[..., N_{res}, 3]` . + + Returns: + - Tuple, rots :math:`[xx, xy, xz, yx, yy, yz, zx, zy, zz]` , + the shape of every element is :math:`(..., N_{res})` . + - Tuple, trans :math:`[x, y, z]` , the shape of every element is :math:`(..., N_{res})` . 
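+
+    Note: only the rotation is computed from all three points; the returned
+    translation is simply `point_b` (the CA position) split into its components.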
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindsponge.common.geometry import make_transform_from_reference + >>> input_0 = Tensor(np.ones((4, 256, 3)), ms.float32) + >>> input_1 = Tensor(np.ones((4, 256, 3)), ms.float32) + >>> input_2 = Tensor(np.ones((4, 256, 3)), ms.float32) + >>> rots, trans = make_transform_from_reference(input_0, input_1, input_2) + >>> print(len(rots), rots[0].shape, len(trans), trans[0].shape) + 9, (4, 256), 3, (4, 256) + """ + + # step 1 : shift the crd system by -point_b (point_b is the origin) + translation = -point_b + point_c = point_c + translation + point_a = point_a + translation + # step 2: rotate the crd system around z-axis to put point_c on x-z plane + c_x, c_y, c_z = vecs_from_tensor(point_c) + sin_c1 = -c_y / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2) + cos_c1 = c_x / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2) + zeros = mnp.zeros_like(sin_c1) + ones = mnp.ones_like(sin_c1) + c1_rot_matrix = (cos_c1, -sin_c1, zeros, + sin_c1, cos_c1, zeros, + zeros, zeros, ones) + # step 2 : rotate the crd system around y_axis to put point_c on x-axis + sin_c2 = c_z / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2) + cos_c2 = mnp.sqrt(c_x ** 2 + c_y ** 2) / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2) + c2_rot_matrix = (cos_c2, zeros, sin_c2, + zeros, ones, zeros, + -sin_c2, zeros, cos_c2) + c_rot_matrix = rots_mul_rots(c2_rot_matrix, c1_rot_matrix) + # step 3: rotate the crd system in y-z plane to put point_a in x-y plane + vec_a = vecs_from_tensor(point_a) + _, rotated_a_y, rotated_a_z = rots_mul_vecs(c_rot_matrix, vec_a) + + sin_n = -rotated_a_z / mnp.sqrt(1e-20 + rotated_a_y ** 2 + rotated_a_z ** 2) + cos_n = rotated_a_y / mnp.sqrt(1e-20 + rotated_a_y ** 2 + rotated_a_z ** 2) + a_rot_matrix = (ones, zeros, zeros, + zeros, cos_n, -sin_n, + zeros, sin_n, cos_n) + rotation_matrix = rots_mul_rots(a_rot_matrix, c_rot_matrix) + translation = point_b + translation = vecs_from_tensor(translation) + return rotation_matrix, translation + + +def rots_from_tensor(rots, use_numpy=False): + """ + Amortize and split the 3*3 rotation matrix corresponding to the last two axes of input Tensor + to obtain each component of the rotation matrix, inverse of 'rots_to_tensor'. + + Args: + rots(Tensor): Represent the rotation matrix, shape is :math:`(..., 3, 3)` . + use_numpy(bool): Whether to use numpy to calculate. Default: False. + + Returns: + Tuple, rots represented by vectors, shape is :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindsponge.common.geometry import rots_from_tensor + >>> input_0 = Tensor(np.ones((256, 3, 3)), ms.float32) + >>> output = rots_from_tensor(input_0) + >>> print(len(output), output[0].shape) + 9, (256,) + """ + if use_numpy: + rots = np.reshape(rots, rots.shape[:-2] + (9,)) + else: + rots = P.Reshape()(rots, P.Shape()(rots)[:-2] + (9,)) + rotation = (rots[..., 0], rots[..., 1], rots[..., 2], + rots[..., 3], rots[..., 4], rots[..., 5], + rots[..., 6], rots[..., 7], rots[..., 8]) + return rotation + + +def rots_to_tensor(rots, use_numpy=False): + """ + Translate rots represented by vectors to tensor, inverse of 'rots_from_tensor'. + + Args: + rots(Tuple): Rots represented by vectors, shape is :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)` . + use_numpy(bool): Whether to use numpy to calculate. 
Default: False. + + Returns: + Tensor, concat the tensor in last dims, shape :math:`(N_{res}, 3, 3)`. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindsponge.common.geometry import rots_to_tensor + >>> inputs = [Tensor(np.ones((256,)), ms.float32) for i in range(9)] + >>> output = rots_to_tensor(inputs) + >>> print(output.shape) + (256, 3, 3) + """ + assert len(rots) == 9 + if use_numpy: + rots = np.stack(rots, axis=-1) + rots = np.reshape(rots, rots.shape[:-1] + (3, 3)) + else: + rots = mnp.stack(rots, axis=-1) + rots = mnp.reshape(rots, rots.shape[:-1] + (3, 3)) + return rots + + +def quat_affine(quaternion, translation, rotation=None, normalize=True, unstack_inputs=False, use_numpy=False): + """ + Create quat affine representations based on rots and trans. + + Args: + quaternion(tensor): Shape is :math:`(N_{res}, 4)`. + translation(tensor): Shape is :math:`(N_{res}, 3)`. + rotation(tensor): Rots, shape is :math:`(N_{res}, 9)`. Default: None. + normalize(bool): Whether to use normalization. Default: True. + unstack_inputs(bool): Whether input is vector(True) of Tensor(False). Default: False. + use_numpy(bool): Whether to use numpy. Default: False. + + Returns: + result after quat affine. + - quaternion, tensor, shape is :math:`(N_{res}, 4)` . + - rotation, tuple, :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`, shape of every element is :math:`(N_{res},)` . + - translation, tensor, shape is :math:`(N_{res}, 3)` . + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindsponge.common.geometry import quat_affine + >>> input_0 = Tensor(np.ones((256, 4)), ms.float32) + >>> input_1 = Tensor(np.ones((256, 3)), ms.float32) + >>> qua, rot, trans = quat_affine(input_0, input_1) + >>> print(qua.shape, len(rot), rot[0].shape, trans.shape) + (256, 4), 9, (256,), (256, 3) + """ + if unstack_inputs: + if rotation is not None: + rotation = rots_from_tensor(rotation, use_numpy) + translation = vecs_from_tensor(translation) + + if normalize and quaternion is not None: + quaternion = quaternion / mnp.norm(quaternion, axis=-1, keepdims=True) + if rotation is None: + rotation = quat_to_rot(quaternion) + return quaternion, rotation, translation + + +def quat_to_rot(normalized_quat, use_numpy=False): + r""" + Convert a normalized quaternion to a rotation matrix. + + .. math:: + \begin{split} + &xx = 1 - 2 * y * y - 2 * z * z \\ + &xy = 2 * x * y + 2 * w * z \\ + &xz = 2 * x * z - 2 * w * y \\ + &yx = 2 * x * y - 2 * w * z \\ + &yy = 1 - 2 * x * x - 2 * z * z \\ + &yz = 2 * z * y + 2 * w * x \\ + &zx = 2 * x * z + 2 * w * y \\ + &zy = 2 * y * z - 2 * w * x \\ + &zz = 1 - 2 * x * x - 2 * y * y \\ + \end{split} + + Args: + normalized_quat (tensor): normalized quaternion, shape :math:`(N_{res}, 4)`. + use_numpy (bool): use numpy or not, Default: "False". + + Returns: + tuple, rotation :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`, every element shape :math:`(N_{res},)`. 
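+
+    A quick identity check (illustrative): for the unit quaternion
+    :math:`(1, 0, 0, 0)` all cross terms above vanish, so the result is the
+    identity rotation, matching the initialization used by `initial_affine` below.
+
+    >>> import numpy as np
+    >>> import mindspore as ms
+    >>> from mindspore import Tensor
+    >>> from mindsponge.common.geometry import quat_to_rot
+    >>> unit_quat = Tensor(np.array([[1., 0., 0., 0.]]), ms.float32)
+    >>> xx, xy, xz, yx, yy, yz, zx, zy, zz = quat_to_rot(unit_quat)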
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindsponge.common.geometry import quat_to_rot + >>> input_0 = Tensor(np.ones((256, 4)), ms.float32) + >>> output = quat_to_rot(input_0) + >>> print(len(output), output[0].shape) + 9, (256,) + """ + if use_numpy: + rot_tensor = np.sum(np.reshape(QUAT_TO_ROT.asnumpy(), (4, 4, 9)) * normalized_quat[..., :, None, None] \ + * normalized_quat[..., None, :, None], axis=(-3, -2)) + rot_tensor = rots_from_tensor(rot_tensor, use_numpy) + else: + rot_tensor = mnp.sum(mnp.reshape(QUAT_TO_ROT, (4, 4, 9)) * normalized_quat[..., :, None, None] * + normalized_quat[..., None, :, None], axis=(-3, -2)) + rot_tensor = P.Split(-1, 9)(rot_tensor) + rot_tensor = (P.Squeeze()(rot_tensor[0]), P.Squeeze()(rot_tensor[1]), P.Squeeze()(rot_tensor[2]), + P.Squeeze()(rot_tensor[3]), P.Squeeze()(rot_tensor[4]), P.Squeeze()(rot_tensor[5]), + P.Squeeze()(rot_tensor[6]), P.Squeeze()(rot_tensor[7]), P.Squeeze()(rot_tensor[8])) + return rot_tensor + + +def initial_affine(num_residues, use_numpy=False): + """ + Initialize quaternion, rotation, translation of affine. + + Args: + num_residues(int): Number of residues. + use_numpy(bool): Whether to use numpy. Default: False. + + Returns: + result after quat affine. + - quaternion, tensor, shape is :math:`(N_{res}, 4)` . + - rotation, tuple, :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`, shape of every element is :math:`(N_{res},)` . + - translation, tensor, shape is :math:`(N_{res}, 3)` . + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindsponge.common.geometry import initial_affine + >>> output = initial_affine(256) + >>> print(len(output), output[0].shape, len(output[1]), len(output[1][0]), len(output[2]), len(output[2][0])) + >>> print(output[0]) + >>> print(output[1]) + >>> print(output[2]) + 3, (1, 4), 9, 1, 3, 1 + [[1.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]] + (1, 0, 0, 0, 1, 0, 0, 0, 1) + ([0.00000000e+00], [0.00000000e+00], [0.00000000e+00]) + """ + if use_numpy: + quaternion = np.tile(np.reshape(np.asarray([1., 0., 0., 0.]), [1, 4]), [num_residues, 1]) + translation = np.zeros([num_residues, 3]) + else: + quaternion = mnp.tile(mnp.reshape(mnp.asarray([1., 0., 0., 0.]), [1, 4]), [num_residues, 1]) + translation = mnp.zeros([num_residues, 3]) + return quat_affine(quaternion, translation, unstack_inputs=True, use_numpy=use_numpy) + + +def vecs_expand_dims(v, axis): + r""" + Add an extra dimension to the input `v` at the given axis. + + Args: + v(Tuple): Input vector. Length is 3, :math:`(xx, xy, xz)` . + axis(int): Specifies the dimension index at which to expand the shape of `v`. Only constant value is allowed. + + Returns: + Tuple, if the axis is 0, and the shape of :math:`xx` is :math:`(... , X_R)`, where X_R is any number. + If the axis is other value, then expand in the other direction. 
And return the expanded
+        :math:`(xx, xy, xz)` .
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> from mindsponge.common.geometry import vecs_expand_dims
+        >>> from mindspore.common import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> v = (1, 2, 3)
+        >>> axis = 0
+        >>> output = vecs_expand_dims(v, axis)
+        >>> print(output)
+        (Tensor(shape=[1], dtype=Int64, value=[1]), Tensor(shape=[1], dtype=Int64, value=[2]),
+        Tensor(shape=[1], dtype=Int64, value=[3]))
+    """
+    v = (P.ExpandDims()(v[0], axis), P.ExpandDims()(v[1], axis), P.ExpandDims()(v[2], axis))
+    return v
+
+
+def rots_expand_dims(rots, axis):
+    """
+    Adds an additional dimension to `rots` at the given axis.
+
+    Args:
+        rots (Tuple): The rotation matrix is :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`,
+            and xx and xy have the same shape.
+        axis (Int): Specifies the dimension index at which to expand the shape of `rots`.
+            Only constant value is allowed.
+
+    Returns:
+        Tuple, rots. If the axis is 0 and the shape of xx is :math:`(..., X_R)`,
+        where X_R is any number, the expanded shape is :math:`(1, ..., X_R)`.
+        Returns the expanded :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> from mindsponge.common.geometry import rots_expand_dims
+        >>> from mindspore.common import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> rots = (1, 2, 3, 4, 5, 6, 7, 8, 9)
+        >>> axis = 0
+        >>> output = rots_expand_dims(rots, axis)
+        >>> print(output)
+        (Tensor(shape=[1], dtype=Int64, value=[1]), Tensor(shape=[1], dtype=Int64, value=[2]),
+        Tensor(shape=[1], dtype=Int64, value=[3]), Tensor(shape=[1], dtype=Int64, value=[4]),
+        Tensor(shape=[1], dtype=Int64, value=[5]), Tensor(shape=[1], dtype=Int64, value=[6]),
+        Tensor(shape=[1], dtype=Int64, value=[7]), Tensor(shape=[1], dtype=Int64, value=[8]),
+        Tensor(shape=[1], dtype=Int64, value=[9]))
+    """
+    rots = (P.ExpandDims()(rots[0], axis), P.ExpandDims()(rots[1], axis), P.ExpandDims()(rots[2], axis),
+            P.ExpandDims()(rots[3], axis), P.ExpandDims()(rots[4], axis), P.ExpandDims()(rots[5], axis),
+            P.ExpandDims()(rots[6], axis), P.ExpandDims()(rots[7], axis), P.ExpandDims()(rots[8], axis))
+    return rots
+
+
+def invert_point(transformed_point, rotation, translation, extra_dims=0, stack=False, use_numpy=False):
+    r"""
+    The inverse transformation of a rigid body group transformation with respect to a point coordinate,
+    that is, the inverse of `apply_to_point`. It reverses a rotation-translation change of coordinates
+    using the transpose of the rotation matrix :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`
+    and the translation vector :math:`(x, y, z)`.
+
+    First, the initial coordinates are translated, and then the transpose of the rotation matrix is multiplied
+    by rot_point to get the final coordinates.
+
+    .. math::
+        \begin{split}
+        &rot\_point = transformed\_point - translation \\
+        &result = rotation^T * rot\_point \\
+        \end{split}
+
+    The specific procedures of vector subtraction, transpose and multiplication can be found in the
+    apis of vecs_sub, invert_rots, rots_mul_vecs etc.
+
+    Args:
+        transformed_point (Tuple): The initial coordinates of the input, :math:`(x, y, z)`,
+            where x, y and z are Tensor and have the same shape.
+        rotation (Tuple): The rotation matrix, :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`,
+            and xx and xy have the same shape.
+ translation (Tuple): The translation vector shape is :math:`(x, y, z)`, + where x, y and z are Tensor and have the same shape. + extra_dims (int): Control whether to expand dims. Default: 0. + stack (bool): Control whether to transform to tuple. Default: False. + use_numpy(bool): Control whether to use numpy. Default: False. + + Returns: + Tuple, the transformed coordinate of invert point.Length is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> from mindsponge.common.geometry import invert_point + >>> from mindspore.common import Tensor + >>> from mindspore import dtype as mstype + >>> transformed_point = (1, 2, 3) + >>> rotation = (1, 2, 3, 4, 5, 6, 7, 8, 9) + >>> translation = (1, 0.5, -1) + >>> output= invert_point(transformed_point, rotation, translation) + >>> print(output) + (Tensor(shape=[], dtype=Float32, value = 34), Tensor(shape=[], dtype=Float32, value = 39.5), + Tensor(shape=[], dtype=Float32, value = 45)) + """ + if stack: + rotation = rots_from_tensor(rotation, use_numpy) + translation = vecs_from_tensor(translation) + for _ in range(extra_dims): + rotation = rots_expand_dims(rotation, -1) + translation = vecs_expand_dims(translation, -1) + rot_point = vecs_sub(transformed_point, translation) + return rots_mul_vecs(invert_rots(rotation), rot_point) + + +def quat_multiply_by_vec(quat, vec): + r""" + Multiply a quaternion by a pure-vector quaternion. + + .. math:: + \begin{split} + &temp = QUAT_MULTIPLY_BY_VEC * quat[..., :, None, None] * vec[..., None, :, None] \\ + &result = sum(tempc,axis=(-3, -2)) \\ + \end{split} + + Args: + quat (Tensor): Quaternion.Tensor of shape :math:`(..., 4)`. + vec (Tensor): A pure-vector quaternion, :math:`(b, c, d)` not normalized quaternion. + Quaternion can be expressed as :math:`(1, b, c, d)`. + + Returns: + Tensor, the product of a quaternion with a pure vector quaternion. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.common.geometry import quat_multiply_by_vec + >>> from mindspore.common import Tensor + >>> from mindspore import dtype as mstype + >>> np.random.seed(1) + >>> quat = Tensor(np.random.rand(4),dtype=mstype.float32) + >>> vec = Tensor(np.random.rand(3),dtype=mstype.float32) + >>> out = quat_multiply_by_vec(quat, vec) + >>> print(out) + [-0.16203496, 0.03330477, -0.05129148, 0.14417158] + """ + + return mnp.sum(QUAT_MULTIPLY_BY_VEC * quat[..., :, None, None] * vec[..., None, :, None], + axis=(-3, -2)) + + +def pre_compose(quaternion, rotation, translation, update): + r""" + Return a new QuatAffine which applies the transformation update first. + + The process of obtaining the updated translation vector and rotation matrix is as follows: + + .. math:: + \begin{split} + &update = (xx, xy, xz, yx, yy, yz) \\ + &vector_quaternion_update = (xx, xy, xz) \\ + &x = (yx) \\ + &y = (yy) \\ + &z = (yz) \\ + &trans_update = (x, y, z) \\ + &new_quaternion = quaternion + vector_quaternion_update * quaternion \\ + &rotated_trans_update = rotation * trans_update \\ + &new_translation = translation + rotated_trans_update \\ + \end{split} + + vector_quaternion_update and quaternion are multiplied by the quat_multiply_by_vec function, + Affine transformation is performed using the generated new_quaternion and new_translation. + The process of affine transformation is referred to the quat_affine api. + + Args: + quaternion (Tensor): The initial quaternion to be updated, shape :math:`[(..., 4)]`. 
+ rotation (Tuple): Rotation matrix, :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`, + and xx and xy are Tensor and have the same shape. + translation (Tuple): Translation vector :math:`(x, y, z)`, + where x, y and z are Tensor and have the same shape. + update (Tensor): The update-assisted matrix has shape :math:`[(..., 6)]`. + 3-vector of x, y, and z such that the quaternion + update is :math:`(1, x, y, z)` and zero for the 3-vector is the identity + quaternion. 3-vector for translation concatenated. + + Returns: + - Tensor, new quaternion.The updated Tensor tuple has shape :math:`[(..., 4)]`. + - Tuple, the updated rotation matrix :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`, + and xx and xy are Tensor and have the same shape. + - Tuple, the updated translation vector :math:`(x, y, z)`, + where x, y and z are Tensor and have the same shape. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.common.geometry import pre_compose + >>> from mindspore.common import Tensor + >>> from mindspore import dtype as mstype + >>> np.random.seed(1) + >>> quaternion = Tensor(np.random.rand(4),dtype=mstype.float32) + >>> update = Tensor(np.random.rand(6),dtype=mstype.float32) + >>> rotation = Tensor(np.random.rand(9),dtype=mstype.float32) + >>> translation = Tensor(np.random.rand(3),dtype=mstype.float32) + >>> quaternion, rotation, translation = pre_compose(quaternion,rotation,translation,update) + >>> print(quaternion) + [ 0.27905196 0.82475466 -0.05600705 0.48864394] + >>> print(rotation) + (Tensor(shape=[], dtype=Float32, value= 0.516181), Tensor(shape=[], dtype=Float32, value= -0.365098), + Tensor(shape=[], dtype=Float32, value= 0.774765), Tensor(shape=[], dtype=Float32, value= 0.18033), + Tensor(shape=[], dtype=Float32, value= -0.837986), Tensor(shape=[], dtype=Float32, value= -0.515034), + Tensor(shape=[], dtype=Float32, value= 0.837281), Tensor(shape=[], dtype=Float32, value= 0.405564), + Tensor(shape=[], dtype=Float32, value= -0.366714)) + >>> print(translation) + (Tensor(shape=[], dtype=Float32, value= 0.724994), Tensor(shape=[], dtype=Float32, value= 1.47631), + Tensor(shape=[], dtype=Float32, value= 1.40978)) + """ + + vector_quaternion_update, x, y, z = mnp.split(update, [3, 4, 5], axis=-1) + trans_update = [mnp.squeeze(x, axis=-1), mnp.squeeze(y, axis=-1), mnp.squeeze(z, axis=-1)] + new_quaternion = (quaternion + quat_multiply_by_vec(quaternion, vector_quaternion_update)) + rotated_trans_update = rots_mul_vecs(rotation, trans_update) + new_translation = vecs_add(translation, rotated_trans_update) + return quat_affine(new_quaternion, new_translation) + + +def quaternion_to_tensor(quaternion, translation): + r""" + Change quaternion to tensor. + + .. math:: + \begin{split} + &quaternion = [(x_1, y_1, z_1, m_1)] \\ + &translation = [(x_2, y_2, z_2)] \\ + &result = [(x_1, y_1, z_1, m_1, x_2, y_2, z_2)] \\ + \end{split} + + Args: + quaternion (Tensor): Inputs quaternion. Tensor of shape :math:`(..., 4)`. + translation (Tensor): Inputs translation. Tensor of shape :math:`(..., 3)` + + Returns: + Tensor, The result of the concatenation between translation and translation. Tensor of shape :math:`(..., 7)`. 
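+
+    Note: `quaternion_from_tensor` below splits this 7-vector back apart, so the two
+    functions form a round trip (illustrative sketch, outputs omitted):
+
+    >>> import numpy as np
+    >>> from mindspore import Tensor
+    >>> from mindspore import dtype as mstype
+    >>> from mindsponge.common.geometry import quaternion_to_tensor, quaternion_from_tensor
+    >>> quaternion = Tensor(np.ones(4), mstype.float32)
+    >>> translation = tuple(Tensor(t, mstype.float32) for t in (1., 2., 3.))
+    >>> packed = quaternion_to_tensor(quaternion, translation)
+    >>> quat, rot, trans = quaternion_from_tensor(packed)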
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.common.geometry import quaternion_to_tensor + >>> from mindspore.common import Tensor + >>> from mindspore import dtype as mstype + >>> np.random.seed(1) + >>> quaternion = Tensor(np.random.rand(4),dtype=mstype.float32) + >>> translation = Tensor(np.random.rand(3),dtype=mstype.float32) + >>> out = quaternion_to_tensor(quaternion, translation) + >>> print(out) + [0.6631489 0.44137922 0.97213906 0.7425225 0.3549025 0.6535310.5426164 ] + """ + translation = (P.ExpandDims()(translation[0], -1), P.ExpandDims()(translation[1], -1), + P.ExpandDims()(translation[2], -1),) + return mnp.concatenate((quaternion,) + translation, axis=-1) + + +def quaternion_from_tensor(tensor, normalize=False): + r""" + Take the input 'tensor' to get the new 'quaternion', 'rotation', 'translation'. + + .. math:: + \begin{split} + &quaternion = [(x_1, y_1, z_1, m_1)] \\ + &translation = [(x_2, y_2, z_2)] \\ + &result = [(x_1, y_1, z_1, m_1, x_2, y_2, z_2)] \\ + \end{split} + + Affine transformation is performed using the generated quaternion and translation. + The process of affine transformation is referred to the quat_affine api. + + Args: + tensor(Tensor): An initial Tensor of shape is :math:`[(... 7)]`. + normalize(bool): Control whether to find the norm during quat_affine. Default: False. + + Returns: + - Tensor, new quaternion.Tensor of shape :math:`(..., 4)` . + - Tuple, new rotation, :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`, + and xx and xy are Tensor and have the same shape. + - Tuple, translation vector :math:`[(x, y, z)]`, where x, y and z are Tensor and have the same shape. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.common.geometry import quaternion_from_tensor + >>> from mindspore.common import Tensor + >>> tensor = Tensor(np.random.rand(7),dtype=mstype.float32) + >>> quaternion, rotation, translation = quaternion_from_tensor(tensor) + >>> print(quaternion) + [4.17021990e-01, 7.20324516e-01, 1.14374816e-04, 3.02332580e-01] + >>> print(rotation) + (Tensor(shape=[], dtype=Float32, value= 0.60137), Tensor(shape=[], dtype=Float32, value= -0.251994), + Tensor(shape=[], dtype=Float32, value= 0.435651), Tensor(shape=[], dtype=Float32, value= 0.252323), + Tensor(shape=[], dtype=Float32, value= -0.436365), Tensor(shape=[], dtype=Float32, value= -0.600713), + Tensor(shape=[], dtype=Float32, value= 0.43546), Tensor(shape=[], dtype=Float32, value= 0.600851), + Tensor(shape=[], dtype=Float32, value= -0.253555)) + >>> print(translation) + (Tensor(shape=[], dtype=Float32, value= 0.146756),Tensor(shape=[], dtype=Float32, value= 0.0923386), + Tensor(shape=[], dtype=Float32, value= 0.18626)) + """ + quaternion, tx, ty, tz = mnp.split(tensor, [4, 5, 6], axis=-1) + translation = (P.Squeeze()(tx), P.Squeeze()(ty), P.Squeeze()(tz)) + return quat_affine(quaternion, translation, normalize=normalize) + + +def apply_to_point(rotation, translation, point, extra_dims=0): + r""" + Rotate and translate the input coordinates. + + .. math:: + \begin{split} + &rot_point = rotation * point \\ + &result = rot_point + translation \\ + \end{split} + + For specific multiplication and addition procedures, refer to the rots_mul_vecs and vecs_add apis. + + Args: + rotation(Tuple): The rotation matrix :math:`(xx, xy, xz, yx, yy, yz, zx, zy, zz)`, + and xx and xy are Tensor and have the same shape. 
+        translation(Tuple): Translation vector :math:`[(x, y, z)]`,
+            where x, y and z are Tensor and have the same shape.
+        point(Tuple): Initial coordinate values :math:`[(x, y, z)]`,
+            where x, y and z are Tensor and have the same shape.
+        extra_dims(int): Control whether to expand dims. Default: 0.
+
+    Returns:
+        Tuple, the result of the coordinate transformation. Length is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.common.geometry import apply_to_point
+        >>> from mindspore.common import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> np.random.seed(1)
+        >>> rotation = []
+        >>> for i in range(9):
+        ...     rotation.append(Tensor(np.random.rand(4),dtype=mstype.float32))
+        >>> translation = []
+        >>> for i in range(3):
+        ...     translation.append(Tensor(np.random.rand(4),dtype=mstype.float32))
+        >>> point = []
+        >>> for i in range(3):
+        ...     point.append(Tensor(np.random.rand(4),dtype=mstype.float32))
+        >>> out = apply_to_point(rotation, translation, point)
+        >>> print(out)
+        (Tensor(shape=[4], dtype=Float32, value= [ 1.02389336e+00, 1.12493467e+00, 2.54357845e-01, 1.25249946e+00]),
+        Tensor(shape=[4], dtype=Float32, value= [ 9.84841168e-01, 5.20081401e-01, 6.43978953e-01, 6.15328550e-01]),
+        Tensor(shape=[4], dtype=Float32, value= [ 8.62860143e-01, 9.11733627e-01, 1.09284782e+00, 1.44202101e+00]))
+    """
+    for _ in range(extra_dims):
+        rotation = rots_expand_dims(rotation, -1)
+        translation = vecs_expand_dims(translation, -1)
+    rot_point = rots_mul_vecs(rotation, point)
+    result = vecs_add(rot_point, translation)
+    return result
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/common/protein.py b/MindSPONGE/applications/research/Grasp/mindsponge1/common/protein.py
new file mode 100644
index 0000000000000000000000000000000000000000..5874d6b59e0a135de5e9deaf30ea66522b79c9e7
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/common/protein.py
@@ -0,0 +1,300 @@
+# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""protein"""
+import io
+from typing import Any, Mapping, Optional
+import dataclasses
+
+from Bio.PDB import PDBParser
+import numpy as np
+
+from . import residue_constants
+
+FeatureDict = Mapping[str, np.ndarray]
+ModelOutput = Mapping[str, Any]  # Is a nested dict.
+
+
+@dataclasses.dataclass(frozen=True)
+class Protein:
+    """Protein structure representation."""
+
+    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
+    # residue_constants.atom_types, i.e. the first three are N, CA, C.
+    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]
+
+    # Amino-acid type for each residue represented as an integer between 0 and
+    # 20, where 20 is 'X'.
+    aatype: np.ndarray  # [num_res]
+
+    # Binary float mask to indicate presence of a particular atom.
1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If None, then the pdb file must contain a single chain (which + will be parsed). If chain_id is specified (e.g. A), then only that chain + is parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser() + structure = parser.get_structure('none', pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f'Only single model PDBs are supported. Found {len(models)} models.') + model = models[0] + + if chain_id is not None: + chain = model[chain_id] + else: + chains = list(model.get_chains()) + if len(chains) != 1: + raise ValueError( + 'Only single chain PDBs are supported when chain_id not specified. ' + f'Found {len(chains)} chains.') + chain = chains[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + b_factors = [] + + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. These are not supported.') + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + #print(res_shortname, restype_idx) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1. + res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + b_factors.append(res_b_factors) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + b_factors=np.array(b_factors)) + +def from_pdb_string_all_chains(pdb_str: str) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser() + structure = parser.get_structure('none', pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f'Only single model PDBs are supported. 
Found {len(models)} models.') + model = models[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + b_factors = [] + + for chain in model: + last_res_idx = 0 + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. These are not supported.') + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + #print(res_shortname, restype_idx) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1. + res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + if res.id[1] != last_res_idx + 1: + # If there is a gap in the residue index, then add a placeholder + # residue with all-zero atom positions and mask. + atom_positions.extend([np.zeros((residue_constants.atom_type_num, 3))]*(res.id[1] - last_res_idx - 1)) + atom_mask.extend([np.zeros((residue_constants.atom_type_num,))]*(res.id[1] - last_res_idx - 1)) + residue_index.extend([0]*(res.id[1] - last_res_idx - 1)) + b_factors.extend([np.zeros((residue_constants.atom_type_num,))]*(res.id[1] - last_res_idx - 1)) + aatype.extend([residue_constants.restype_num]*(res.id[1] - last_res_idx - 1)) + last_res_idx = res.id[1] + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + b_factors.append(res_b_factors) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + b_factors=np.array(b_factors)) + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ['X'] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + pdb_lines.append('MODEL 1') + atom_index = 1 + chain_id = 'A' + # Add all atom sites. + for i in range(aatype.shape[0]): + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[0] # Protein supports only C, N, O, S, this works. + charge = '' + # PDB is a columnar format, every space matters here! 
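+            # For reference, the fixed column layout of a PDB v3.3 ATOM record that
+            # the f-string below reproduces (1-based columns):
+            #   1-6 record name, 7-11 serial, 13-16 atom name, 17 altLoc, 18-20 resName,
+            #   22 chainID, 23-26 resSeq, 27 iCode, 31-38 / 39-46 / 47-54 x / y / z,
+            #   55-60 occupancy, 61-66 tempFactor, 77-78 element, 79-80 charge.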
+ atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_id:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the chain. + chain_end = 'TER' + chain_termination_line = ( + f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} ' + f'{chain_id:>1}{residue_index[-1]:>4}') + pdb_lines.append(chain_termination_line) + pdb_lines.append('ENDMDL') + + pdb_lines.append('END') + pdb_lines.append('') + return '\n'.join(pdb_lines) + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction(final_atom_positions, final_atom_mask, aatype, residue_index, b_factors=None) -> Protein: + """Assembles a protein from a prediction. + + Args: + final_atom_positions: atom positions + final_atom_mask: atom mask + aatype: amino acid type + residue_index: idx of the residue + Returns: + A protein instance. + """ + if b_factors is None: + b_factors = np.zeros_like(final_atom_mask) + + return Protein( + aatype=aatype, + atom_positions=final_atom_positions, + atom_mask=final_atom_mask, + residue_index=residue_index + 1, + b_factors=b_factors) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/common/residue_constants.py b/MindSPONGE/applications/research/Grasp/mindsponge1/common/residue_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..01291cdef937f10c069bef8a28ae4d238f6323c3 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/common/residue_constants.py @@ -0,0 +1,923 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""residue_constants.""" +import os +import collections +import functools +from typing import Mapping +import numpy as np + +import mindsponge +from mindspore.common.tensor import Tensor + +stereo_chemical_props_path = os.path.dirname(mindsponge.__file__) + "/common/stereo_chemical_props.txt" +QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32) +QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, -1]] + +QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 1], + [0, 0, -1, 0]] + +QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], + [0, 0, 0, -1], + [1, 0, 0, 0], + [0, 1, 0, 0]] + +QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], + [0, 0, 1, 0], + [0, -1, 0, 0], + [1, 0, 0, 0]] + +QUAT_MULTIPLY_BY_VEC = Tensor(QUAT_MULTIPLY[:, 1:, :]) + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']], + 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE']], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. 
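+# (For example, rotating ASP chi2 by pi swaps OD1 and OD2, which are
+# chemically equivalent; cf. residue_atom_renaming_swaps below.)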
+chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 
0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.524, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': [ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, -0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, 
-1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + 'ALA': ['C', 'CA', 'CB', 'N', 'O'], + 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], + 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], + 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], + 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], + 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], + 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], + 'GLY': ['C', 'CA', 'N', 'O'], + 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], + 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], + 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], + 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], + 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], + 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], + 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], + 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], + 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], + 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3', + 'CH2', 'N', 'NE1', 'O'], + 'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', + 'OH'], + 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'] +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +residue_atom_renaming_swaps = { + 'ASP': {'OD1': 'OD2'}, + 'GLU': {'OE1': 'OE2'}, + 'PHE': {'CD1': 'CD2', 'CE1': 'CE2'}, + 'TYR': {'CD1': 'CD2', 'CE1': 'CE2'}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + 'C': 1.7, + 'N': 1.55, + 'O': 1.52, + 'S': 1.8, +} + +Bond = collections.namedtuple( + 'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev']) +BondAngle = collections.namedtuple( + 'BondAngle', + ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev']) + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props(): + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). 
+ + Returns: + residue_bonds: dict that maps resname --> list of Bond tuples + residue_virtual_bonds: dict that maps resname --> list of Bond tuples + residue_bond_angles: dict that maps resname --> list of BondAngle tuples + """ + with open(stereo_chemical_props_path, 'rt') as f: + stereo_chemical_props = f.read() + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split('-') + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds['UNK'] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split('-') + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle(atom1, atom2, atom3, + float(angle_degree) / 180. * np.pi, + float(stddev_degree) / 180. * np.pi)) + residue_bond_angles['UNK'] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return '-'.join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 + - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length * + np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length - 2 * bond2.length * + np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length - 2 * bond1.length * + np.cos(gamma)) * dl_outer + stddev = np.sqrt((dl_dgamma * ba.stddev)**2 + + (dl_db1 * bond1.stddev)**2 + + (dl_db2 * bond2.stddev)**2) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return residue_bonds, residue_virtual_bonds, residue_bond_angles + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). 
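+# For example, atom_order['N'] == 0 and atom_order['CA'] == 1 in the arrays
+# defined below, and every residue is padded to atom_type_num == 37 slots.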
+atom_types = [ + 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD', + 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3', + 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2', + 'CZ3', 'NZ', 'OXT' +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''], + 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''], + 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], + 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''], + 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], + +} + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', + 'S', 'T', 'W', 'Y', 'V' +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. + +restypes_with_x = restypes + ['X'] +restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} +order_restype_with_x = {i: restype for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot( + sequence: str, + mapping: Mapping[str, int], + map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. 
+ + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError( + 'The mapping must have values from 0 to num_unique_aas-1 ' + 'without any gaps. Got: %s' % + sorted( + mapping.values())) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping['X']) + else: + raise ValueError( + f'Invalid character in the sequence: {aa_type}') + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'Q': 'GLN', + 'E': 'GLU', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = 'UNK' + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + 'A': 0, + 'B': 2, + 'C': 1, + 'D': 2, + 'E': 3, + 'F': 4, + 'G': 5, + 'H': 6, + 'I': 7, + 'J': 20, + 'K': 8, + 'L': 9, + 'M': 10, + 'N': 11, + 'O': 20, + 'P': 12, + 'Q': 13, + 'R': 14, + 'S': 15, + 'T': 16, + 'U': 1, + 'V': 17, + 'W': 18, + 'X': 20, + 'Y': 19, + 'Z': 3, + '-': 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: 'A', + 1: 'C', # Also U. + 2: 'D', # Also B. + 3: 'E', # Also Z. + 4: 'F', + 5: 'G', + 6: 'H', + 7: 'I', + 8: 'K', + 9: 'L', + 10: 'M', + 11: 'N', + 12: 'P', + 13: 'Q', + 14: 'R', + 15: 'S', + 16: 'T', + 17: 'V', + 18: 'W', + 19: 'Y', + 20: 'X', # Includes J and O. + 21: '-', +} + +restypes_with_x_and_gap = restypes + ['X', '-'] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap))) + +MSA_GAP_IDX = restypes_with_x_and_gap.index('-') +MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX, + 'msa_mask_all_seq': 1, + 'deletion_matrix_all_seq': 0, + 'deletion_matrix_int_all_seq': 0, + 'msa': MSA_GAP_IDX, + 'msa_mask': 1, + 'deletion_matrix': 0, + 'deletion_matrix_int': 0} + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). 
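+    # The resulting mask has shape [21, 37]; row 20 (the unknown restype)
+    # is left all zeros.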
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3.get(restype_letter) + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3.get(r) + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. 
+    ex_normalized = ex / np.linalg.norm(ex)
+
+    # Make ey perpendicular to ex.
+    ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
+    ey_normalized /= np.linalg.norm(ey_normalized)
+
+    # Compute ez as the cross product.
+    eznorm = np.cross(ex_normalized, ey_normalized)
+    m = np.stack([ex_normalized, ey_normalized,
+                  eznorm, translation]).transpose()
+    m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0)
+    return m
+
+
+# create an array with (restype, atomtype) --> rigid_group_idx
+# and an array with (restype, atomtype, coord) for the atom positions
+# and compute affine transformation matrices (4,4) from one rigid group to the
+# previous group
+restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int32)
+restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
+restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int32)
+restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
+restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
+restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
+
+
+def _make_rigid_group_constants():
+    """Fill the arrays above."""
+    for restype, restype_letter in enumerate(restypes):
+        resname = restype_1to3.get(restype_letter)
+        for atomname, group_idx, atom_position in rigid_group_atom_positions.get(resname):
+            atomtype = atom_order[atomname]
+            restype_atom37_to_rigid_group[restype, atomtype] = group_idx
+            restype_atom37_mask[restype, atomtype] = 1
+            restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
+
+            atom14idx = restype_name_to_atom14_names.get(resname).index(atomname)
+            restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
+            restype_atom14_mask[restype, atom14idx] = 1
+            restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position
+
+    for restype, restype_letter in enumerate(restypes):
+        resname = restype_1to3[restype_letter]
+        atom_positions = {name: np.array(pos) for name, _, pos
+                          in rigid_group_atom_positions[resname]}
+
+        # backbone to backbone is the identity transform
+        restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
+
+        # pre-omega-frame to backbone (currently dummy identity matrix)
+        restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
+
+        # phi-frame to backbone
+        mat = _make_rigid_transformation_4x4(
+            ex=atom_positions['N'] - atom_positions['CA'],
+            ey=np.array([1., 0., 0.]),
+            translation=atom_positions['N'])
+        restype_rigid_group_default_frame[restype, 2, :, :] = mat
+
+        # psi-frame to backbone
+        mat = _make_rigid_transformation_4x4(
+            ex=atom_positions['C'] - atom_positions['CA'],
+            ey=atom_positions['CA'] - atom_positions['N'],
+            translation=atom_positions['C'])
+        restype_rigid_group_default_frame[restype, 3, :, :] = mat
+
+        # chi1-frame to backbone
+        if chi_angles_mask[restype][0]:
+            base_atom_names = chi_angles_atoms[resname][0]
+            base_atom_positions = [atom_positions[name]
+                                   for name in base_atom_names]
+            mat = _make_rigid_transformation_4x4(
+                ex=base_atom_positions[2] - base_atom_positions[1],
+                ey=base_atom_positions[0] - base_atom_positions[1],
+                translation=base_atom_positions[2])
+            restype_rigid_group_default_frame[restype, 4, :, :] = mat
+
+        # chi2-frame to chi1-frame
+        # chi3-frame to chi2-frame
+        # chi4-frame to chi3-frame
+        # luckily all rotation axes for the next frame start at (0,0,0) of the
+        # previous frame
+        for chi_idx in range(1, 4):
+            if chi_angles_mask[restype][chi_idx]:
+                axis_end_atom_name = 
chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1., 0., 0.]), + translation=axis_end_atom_position) + restype_rigid_group_default_frame[restype, + 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3.get(restype_letter) + atom_list = restype_name_to_atom14_names.get(resname) + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, + atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, + atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, + atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, + atom2_idx, atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, + atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, + atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, + atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, + atom2_idx, atom1_idx] = upper + restype_atom14_bond_stddev[restype, + atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, + atom2_idx, atom1_idx] = b.stddev + return restype_atom14_bond_lower_bound, restype_atom14_bond_upper_bound, restype_atom14_bond_stddev diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/common/stereo_chemical_props.txt b/MindSPONGE/applications/research/Grasp/mindsponge1/common/stereo_chemical_props.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ead07a39e297b910c09ea543538109bfa55eecd --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/common/stereo_chemical_props.txt @@ -0,0 +1,345 @@ +Bond Residue Mean StdDev +CA-CB ALA 1.520 0.021 +N-CA ALA 1.459 0.020 +CA-C ALA 1.525 0.026 +C-O ALA 1.229 0.019 +CA-CB ARG 1.535 0.022 +CB-CG ARG 1.521 0.027 +CG-CD ARG 1.515 0.025 +CD-NE ARG 1.460 0.017 +NE-CZ ARG 1.326 0.013 +CZ-NH1 ARG 1.326 0.013 +CZ-NH2 ARG 1.326 0.013 +N-CA ARG 1.459 0.020 +CA-C ARG 1.525 0.026 +C-O ARG 1.229 0.019 +CA-CB ASN 1.527 0.026 +CB-CG ASN 1.506 0.023 +CG-OD1 ASN 1.235 0.022 +CG-ND2 ASN 1.324 0.025 +N-CA ASN 1.459 0.020 +CA-C ASN 1.525 0.026 +C-O ASN 1.229 0.019 +CA-CB ASP 1.535 0.022 +CB-CG ASP 1.513 0.021 +CG-OD1 ASP 1.249 0.023 +CG-OD2 ASP 
1.249 0.023 +N-CA ASP 1.459 0.020 +CA-C ASP 1.525 0.026 +C-O ASP 1.229 0.019 +CA-CB CYS 1.526 0.013 +CB-SG CYS 1.812 0.016 +N-CA CYS 1.459 0.020 +CA-C CYS 1.525 0.026 +C-O CYS 1.229 0.019 +CA-CB GLU 1.535 0.022 +CB-CG GLU 1.517 0.019 +CG-CD GLU 1.515 0.015 +CD-OE1 GLU 1.252 0.011 +CD-OE2 GLU 1.252 0.011 +N-CA GLU 1.459 0.020 +CA-C GLU 1.525 0.026 +C-O GLU 1.229 0.019 +CA-CB GLN 1.535 0.022 +CB-CG GLN 1.521 0.027 +CG-CD GLN 1.506 0.023 +CD-OE1 GLN 1.235 0.022 +CD-NE2 GLN 1.324 0.025 +N-CA GLN 1.459 0.020 +CA-C GLN 1.525 0.026 +C-O GLN 1.229 0.019 +N-CA GLY 1.456 0.015 +CA-C GLY 1.514 0.016 +C-O GLY 1.232 0.016 +CA-CB HIS 1.535 0.022 +CB-CG HIS 1.492 0.016 +CG-ND1 HIS 1.369 0.015 +CG-CD2 HIS 1.353 0.017 +ND1-CE1 HIS 1.343 0.025 +CD2-NE2 HIS 1.415 0.021 +CE1-NE2 HIS 1.322 0.023 +N-CA HIS 1.459 0.020 +CA-C HIS 1.525 0.026 +C-O HIS 1.229 0.019 +CA-CB ILE 1.544 0.023 +CB-CG1 ILE 1.536 0.028 +CB-CG2 ILE 1.524 0.031 +CG1-CD1 ILE 1.500 0.069 +N-CA ILE 1.459 0.020 +CA-C ILE 1.525 0.026 +C-O ILE 1.229 0.019 +CA-CB LEU 1.533 0.023 +CB-CG LEU 1.521 0.029 +CG-CD1 LEU 1.514 0.037 +CG-CD2 LEU 1.514 0.037 +N-CA LEU 1.459 0.020 +CA-C LEU 1.525 0.026 +C-O LEU 1.229 0.019 +CA-CB LYS 1.535 0.022 +CB-CG LYS 1.521 0.027 +CG-CD LYS 1.520 0.034 +CD-CE LYS 1.508 0.025 +CE-NZ LYS 1.486 0.025 +N-CA LYS 1.459 0.020 +CA-C LYS 1.525 0.026 +C-O LYS 1.229 0.019 +CA-CB MET 1.535 0.022 +CB-CG MET 1.509 0.032 +CG-SD MET 1.807 0.026 +SD-CE MET 1.774 0.056 +N-CA MET 1.459 0.020 +CA-C MET 1.525 0.026 +C-O MET 1.229 0.019 +CA-CB PHE 1.535 0.022 +CB-CG PHE 1.509 0.017 +CG-CD1 PHE 1.383 0.015 +CG-CD2 PHE 1.383 0.015 +CD1-CE1 PHE 1.388 0.020 +CD2-CE2 PHE 1.388 0.020 +CE1-CZ PHE 1.369 0.019 +CE2-CZ PHE 1.369 0.019 +N-CA PHE 1.459 0.020 +CA-C PHE 1.525 0.026 +C-O PHE 1.229 0.019 +CA-CB PRO 1.531 0.020 +CB-CG PRO 1.495 0.050 +CG-CD PRO 1.502 0.033 +CD-N PRO 1.474 0.014 +N-CA PRO 1.468 0.017 +CA-C PRO 1.524 0.020 +C-O PRO 1.228 0.020 +CA-CB SER 1.525 0.015 +CB-OG SER 1.418 0.013 +N-CA SER 1.459 0.020 +CA-C SER 1.525 0.026 +C-O SER 1.229 0.019 +CA-CB THR 1.529 0.026 +CB-OG1 THR 1.428 0.020 +CB-CG2 THR 1.519 0.033 +N-CA THR 1.459 0.020 +CA-C THR 1.525 0.026 +C-O THR 1.229 0.019 +CA-CB TRP 1.535 0.022 +CB-CG TRP 1.498 0.018 +CG-CD1 TRP 1.363 0.014 +CG-CD2 TRP 1.432 0.017 +CD1-NE1 TRP 1.375 0.017 +NE1-CE2 TRP 1.371 0.013 +CD2-CE2 TRP 1.409 0.012 +CD2-CE3 TRP 1.399 0.015 +CE2-CZ2 TRP 1.393 0.017 +CE3-CZ3 TRP 1.380 0.017 +CZ2-CH2 TRP 1.369 0.019 +CZ3-CH2 TRP 1.396 0.016 +N-CA TRP 1.459 0.020 +CA-C TRP 1.525 0.026 +C-O TRP 1.229 0.019 +CA-CB TYR 1.535 0.022 +CB-CG TYR 1.512 0.015 +CG-CD1 TYR 1.387 0.013 +CG-CD2 TYR 1.387 0.013 +CD1-CE1 TYR 1.389 0.015 +CD2-CE2 TYR 1.389 0.015 +CE1-CZ TYR 1.381 0.013 +CE2-CZ TYR 1.381 0.013 +CZ-OH TYR 1.374 0.017 +N-CA TYR 1.459 0.020 +CA-C TYR 1.525 0.026 +C-O TYR 1.229 0.019 +CA-CB VAL 1.543 0.021 +CB-CG1 VAL 1.524 0.021 +CB-CG2 VAL 1.524 0.021 +N-CA VAL 1.459 0.020 +CA-C VAL 1.525 0.026 +C-O VAL 1.229 0.019 +- + +Angle Residue Mean StdDev +N-CA-CB ALA 110.1 1.4 +CB-CA-C ALA 110.1 1.5 +N-CA-C ALA 111.0 2.7 +CA-C-O ALA 120.1 2.1 +N-CA-CB ARG 110.6 1.8 +CB-CA-C ARG 110.4 2.0 +CA-CB-CG ARG 113.4 2.2 +CB-CG-CD ARG 111.6 2.6 +CG-CD-NE ARG 111.8 2.1 +CD-NE-CZ ARG 123.6 1.4 +NE-CZ-NH1 ARG 120.3 0.5 +NE-CZ-NH2 ARG 120.3 0.5 +NH1-CZ-NH2 ARG 119.4 1.1 +N-CA-C ARG 111.0 2.7 +CA-C-O ARG 120.1 2.1 +N-CA-CB ASN 110.6 1.8 +CB-CA-C ASN 110.4 2.0 +CA-CB-CG ASN 113.4 2.2 +CB-CG-ND2 ASN 116.7 2.4 +CB-CG-OD1 ASN 121.6 2.0 +ND2-CG-OD1 ASN 121.9 2.3 +N-CA-C ASN 111.0 2.7 +CA-C-O ASN 120.1 2.1 +N-CA-CB ASP 110.6 1.8 +CB-CA-C ASP 
110.4 2.0 +CA-CB-CG ASP 113.4 2.2 +CB-CG-OD1 ASP 118.3 0.9 +CB-CG-OD2 ASP 118.3 0.9 +OD1-CG-OD2 ASP 123.3 1.9 +N-CA-C ASP 111.0 2.7 +CA-C-O ASP 120.1 2.1 +N-CA-CB CYS 110.8 1.5 +CB-CA-C CYS 111.5 1.2 +CA-CB-SG CYS 114.2 1.1 +N-CA-C CYS 111.0 2.7 +CA-C-O CYS 120.1 2.1 +N-CA-CB GLU 110.6 1.8 +CB-CA-C GLU 110.4 2.0 +CA-CB-CG GLU 113.4 2.2 +CB-CG-CD GLU 114.2 2.7 +CG-CD-OE1 GLU 118.3 2.0 +CG-CD-OE2 GLU 118.3 2.0 +OE1-CD-OE2 GLU 123.3 1.2 +N-CA-C GLU 111.0 2.7 +CA-C-O GLU 120.1 2.1 +N-CA-CB GLN 110.6 1.8 +CB-CA-C GLN 110.4 2.0 +CA-CB-CG GLN 113.4 2.2 +CB-CG-CD GLN 111.6 2.6 +CG-CD-OE1 GLN 121.6 2.0 +CG-CD-NE2 GLN 116.7 2.4 +OE1-CD-NE2 GLN 121.9 2.3 +N-CA-C GLN 111.0 2.7 +CA-C-O GLN 120.1 2.1 +N-CA-C GLY 113.1 2.5 +CA-C-O GLY 120.6 1.8 +N-CA-CB HIS 110.6 1.8 +CB-CA-C HIS 110.4 2.0 +CA-CB-CG HIS 113.6 1.7 +CB-CG-ND1 HIS 123.2 2.5 +CB-CG-CD2 HIS 130.8 3.1 +CG-ND1-CE1 HIS 108.2 1.4 +ND1-CE1-NE2 HIS 109.9 2.2 +CE1-NE2-CD2 HIS 106.6 2.5 +NE2-CD2-CG HIS 109.2 1.9 +CD2-CG-ND1 HIS 106.0 1.4 +N-CA-C HIS 111.0 2.7 +CA-C-O HIS 120.1 2.1 +N-CA-CB ILE 110.8 2.3 +CB-CA-C ILE 111.6 2.0 +CA-CB-CG1 ILE 111.0 1.9 +CB-CG1-CD1 ILE 113.9 2.8 +CA-CB-CG2 ILE 110.9 2.0 +CG1-CB-CG2 ILE 111.4 2.2 +N-CA-C ILE 111.0 2.7 +CA-C-O ILE 120.1 2.1 +N-CA-CB LEU 110.4 2.0 +CB-CA-C LEU 110.2 1.9 +CA-CB-CG LEU 115.3 2.3 +CB-CG-CD1 LEU 111.0 1.7 +CB-CG-CD2 LEU 111.0 1.7 +CD1-CG-CD2 LEU 110.5 3.0 +N-CA-C LEU 111.0 2.7 +CA-C-O LEU 120.1 2.1 +N-CA-CB LYS 110.6 1.8 +CB-CA-C LYS 110.4 2.0 +CA-CB-CG LYS 113.4 2.2 +CB-CG-CD LYS 111.6 2.6 +CG-CD-CE LYS 111.9 3.0 +CD-CE-NZ LYS 111.7 2.3 +N-CA-C LYS 111.0 2.7 +CA-C-O LYS 120.1 2.1 +N-CA-CB MET 110.6 1.8 +CB-CA-C MET 110.4 2.0 +CA-CB-CG MET 113.3 1.7 +CB-CG-SD MET 112.4 3.0 +CG-SD-CE MET 100.2 1.6 +N-CA-C MET 111.0 2.7 +CA-C-O MET 120.1 2.1 +N-CA-CB PHE 110.6 1.8 +CB-CA-C PHE 110.4 2.0 +CA-CB-CG PHE 113.9 2.4 +CB-CG-CD1 PHE 120.8 0.7 +CB-CG-CD2 PHE 120.8 0.7 +CD1-CG-CD2 PHE 118.3 1.3 +CG-CD1-CE1 PHE 120.8 1.1 +CG-CD2-CE2 PHE 120.8 1.1 +CD1-CE1-CZ PHE 120.1 1.2 +CD2-CE2-CZ PHE 120.1 1.2 +CE1-CZ-CE2 PHE 120.0 1.8 +N-CA-C PHE 111.0 2.7 +CA-C-O PHE 120.1 2.1 +N-CA-CB PRO 103.3 1.2 +CB-CA-C PRO 111.7 2.1 +CA-CB-CG PRO 104.8 1.9 +CB-CG-CD PRO 106.5 3.9 +CG-CD-N PRO 103.2 1.5 +CA-N-CD PRO 111.7 1.4 +N-CA-C PRO 112.1 2.6 +CA-C-O PRO 120.2 2.4 +N-CA-CB SER 110.5 1.5 +CB-CA-C SER 110.1 1.9 +CA-CB-OG SER 111.2 2.7 +N-CA-C SER 111.0 2.7 +CA-C-O SER 120.1 2.1 +N-CA-CB THR 110.3 1.9 +CB-CA-C THR 111.6 2.7 +CA-CB-OG1 THR 109.0 2.1 +CA-CB-CG2 THR 112.4 1.4 +OG1-CB-CG2 THR 110.0 2.3 +N-CA-C THR 111.0 2.7 +CA-C-O THR 120.1 2.1 +N-CA-CB TRP 110.6 1.8 +CB-CA-C TRP 110.4 2.0 +CA-CB-CG TRP 113.7 1.9 +CB-CG-CD1 TRP 127.0 1.3 +CB-CG-CD2 TRP 126.6 1.3 +CD1-CG-CD2 TRP 106.3 0.8 +CG-CD1-NE1 TRP 110.1 1.0 +CD1-NE1-CE2 TRP 109.0 0.9 +NE1-CE2-CD2 TRP 107.3 1.0 +CE2-CD2-CG TRP 107.3 0.8 +CG-CD2-CE3 TRP 133.9 0.9 +NE1-CE2-CZ2 TRP 130.4 1.1 +CE3-CD2-CE2 TRP 118.7 1.2 +CD2-CE2-CZ2 TRP 122.3 1.2 +CE2-CZ2-CH2 TRP 117.4 1.0 +CZ2-CH2-CZ3 TRP 121.6 1.2 +CH2-CZ3-CE3 TRP 121.2 1.1 +CZ3-CE3-CD2 TRP 118.8 1.3 +N-CA-C TRP 111.0 2.7 +CA-C-O TRP 120.1 2.1 +N-CA-CB TYR 110.6 1.8 +CB-CA-C TYR 110.4 2.0 +CA-CB-CG TYR 113.4 1.9 +CB-CG-CD1 TYR 121.0 0.6 +CB-CG-CD2 TYR 121.0 0.6 +CD1-CG-CD2 TYR 117.9 1.1 +CG-CD1-CE1 TYR 121.3 0.8 +CG-CD2-CE2 TYR 121.3 0.8 +CD1-CE1-CZ TYR 119.8 0.9 +CD2-CE2-CZ TYR 119.8 0.9 +CE1-CZ-CE2 TYR 119.8 1.6 +CE1-CZ-OH TYR 120.1 2.7 +CE2-CZ-OH TYR 120.1 2.7 +N-CA-C TYR 111.0 2.7 +CA-C-O TYR 120.1 2.1 +N-CA-CB VAL 111.5 2.2 +CB-CA-C VAL 111.4 1.9 +CA-CB-CG1 VAL 110.9 1.5 +CA-CB-CG2 VAL 110.9 1.5 +CG1-CB-CG2 VAL 110.9 1.6 
+N-CA-C VAL 111.0 2.7 +CA-C-O VAL 120.1 2.1 +- + +Non-bonded distance Minimum Dist Tolerance +C-C 3.4 1.5 +C-N 3.25 1.5 +C-S 3.5 1.5 +C-O 3.22 1.5 +N-N 3.1 1.5 +N-S 3.35 1.5 +N-O 3.07 1.5 +O-S 3.32 1.5 +O-O 3.04 1.5 +S-S 2.03 1.0 +- diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/common/utils.py b/MindSPONGE/applications/research/Grasp/mindsponge1/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56448775b95bf4df4e118f1a8699b5012cad11ef --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/common/utils.py @@ -0,0 +1,959 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""utils module""" + +import numpy as np +from Bio import Align +from Bio.Align import substitution_matrices +from mindspore import nn +import mindspore.numpy as mnp +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from . import geometry +from . import residue_constants, protein + + + +class MemoryReduceCell(nn.Cell): + def __init__(self, slice_num, device_num, strategy, dim=0, gather_dim=0): + super(MemoryReduceCell, self).__init__() + self.slice_num = slice_num + self.dim = dim + self.strategy = strategy + concat_strategy = [] + for i in range(slice_num): + concat_strategy.append((1, device_num, 1)) + self.concat = P.Concat(gather_dim).shard(tuple(concat_strategy)) + if self.slice_num > 1: + self.gathers = [] + for i in range(len(strategy)): + self.gathers.append(P.Gather().shard(strategy[i])) + + def construct(self, body, batched_inputs, nonbatched_inputs): + if self.slice_num <= 1: + inputs = batched_inputs + nonbatched_inputs + return body(*inputs) + interval = [] + inner_split_res = 0 + for val in batched_inputs: + interval.append(val.shape[self.dim] / self.slice_num) + + inner_split_inputs = () + val_index = 0 + for n in range(len(self.strategy)): + val = batched_inputs[n] + input_indices = mnp.arange(0, interval[val_index]).astype(mnp.int32) + inner_val = self.gathers[n](val, input_indices, self.dim) + inner_split_inputs = inner_split_inputs + (inner_val,) + val_index += 1 + inner_split_inputs = inner_split_inputs + nonbatched_inputs + inner_split_res = body(*inner_split_inputs) + + res = (inner_split_res,) + for i in range(1, self.slice_num): + inner_split_inputs = () + val_index = 0 + for n in range(len(self.strategy)): + val = batched_inputs[n] + input_indices = mnp.arange(i*interval[val_index], (i + 1) * interval[val_index]).astype(mnp.int32) + val = F.depend(val, res[-1]) + inner_val = self.gathers[n](val, input_indices, self.dim) + inner_split_inputs = inner_split_inputs + (inner_val,) + val_index += 1 + inner_split_inputs = inner_split_inputs + nonbatched_inputs + inner_split_inputs = F.depend(inner_split_inputs, res[-1]) + inner_split_res = body(*inner_split_inputs) + res = res + (inner_split_res,) + res = 
self.concat(res)
+        return res
+
+
+def _memory_reduce(body, batched_inputs, nonbatched_inputs, slice_num, dim=0):
+    """memory reduce function"""
+    if slice_num <= 1:
+        inputs = batched_inputs + nonbatched_inputs
+        return body(*inputs)
+    inner_batched_inputs = []
+    for val in batched_inputs:
+        inner_val = P.Split(dim, slice_num)(val)
+        inner_batched_inputs.append(inner_val)
+    # First slice; later slices are chained via F.depend so that they run
+    # sequentially and peak memory stays bounded.
+    inner_split_batched_inputs = ()
+    for j in range(len(inner_batched_inputs)):
+        inner_split_batched_inputs = inner_split_batched_inputs + (inner_batched_inputs[j][0],)
+    inner_split_inputs = inner_split_batched_inputs + nonbatched_inputs
+    inner_split_res = body(*inner_split_inputs)
+    res = (inner_split_res,)
+    for i in range(1, slice_num):
+        inner_split_batched_inputs = ()
+        for j in range(len(inner_batched_inputs)):
+            inner_split_batched_inputs = inner_split_batched_inputs + (inner_batched_inputs[j][i],)
+        inner_split_inputs = inner_split_batched_inputs + nonbatched_inputs
+        inner_split_inputs = F.depend(inner_split_inputs, res[-1])
+        inner_split_res = body(*inner_split_inputs)
+        res = res + (inner_split_res,)
+    res = P.Concat()(res)
+    return res
+
+
+def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
+    """Create pseudo beta features."""
+
+    is_gly = mnp.equal(aatype, residue_constants.restype_order['G'])
+    ca_idx = residue_constants.atom_order['CA']
+    cb_idx = residue_constants.atom_order['CB']
+    pseudo_beta = mnp.where(
+        mnp.tile(is_gly[..., None], [1,] * len(is_gly.shape) + [3,]),
+        all_atom_positions[..., ca_idx, :],
+        all_atom_positions[..., 
cb_idx, :])
+    if all_atom_masks is not None:
+        pseudo_beta_mask = mnp.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
+        pseudo_beta_mask = pseudo_beta_mask.astype(mnp.float32)
+        return pseudo_beta, pseudo_beta_mask
+    return pseudo_beta
+
+
+def dgram_from_positions(positions, num_bins, min_bin, max_bin, ret_type):
+    """Compute distogram from amino acid positions.
+
+    Args:
+        positions: [N_res, 3] Position coordinates.
+        num_bins: The number of bins in the distogram.
+        min_bin: The left edge of the first bin.
+        max_bin: The left edge of the final bin. The final bin catches
+            everything larger than `max_bin`.
+        ret_type: dtype of the returned distogram.
+
+    Returns:
+        Distogram with the specified number of bins.
+    """
+
+    def squared_difference(x, y):
+        return mnp.square(x - y)
+
+    lower_breaks = mnp.linspace(min_bin, max_bin, num_bins)
+    lower_breaks = mnp.square(lower_breaks)
+    upper_breaks = mnp.concatenate([lower_breaks[1:], mnp.array([1e8], dtype=mnp.float32)], axis=-1)
+    dist2 = mnp.sum(squared_difference(mnp.expand_dims(positions, axis=-2),
+                                       mnp.expand_dims(positions, axis=-3)), axis=-1, keepdims=True)
+    dgram = ((dist2 > lower_breaks).astype(ret_type) * (dist2 < upper_breaks).astype(ret_type))
+    return dgram
+
+
+def atom37_to_torsion_angles(
+        aatype,  # (B, N)
+        all_atom_pos,  # (B, N, 37, 3)
+        all_atom_mask,  # (B, N, 37)
+        chi_atom_indices,
+        chi_angles_mask,
+        mirror_psi_mask,
+        chi_pi_periodic,
+        indices0,
+        indices1
+):
+    """Computes the 7 torsion angles (in sin, cos encoding) for each residue.
+
+    The 7 torsion angles are in the order
+    '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
+    here pre_omega denotes the omega torsion angle between the given amino acid
+    and the previous amino acid.
+
+    Args:
+        aatype: Amino acid type, given as array with integers.
+        all_atom_pos: atom37 representation of all atom coordinates.
+        all_atom_mask: atom37 representation of mask on all atom coordinates.
+        chi_atom_indices: table of atom37 indices of the 4 atoms defining each
+            chi angle, shape [restypes, chis=4, atoms=4].
+        chi_angles_mask: mask of which chi angles exist per residue type.
+        mirror_psi_mask: sign mask used to mirror psi, which is computed from
+            the oxygen atom.
+        chi_pi_periodic: per-residue-type mask of pi-periodic chi angles.
+        indices0: precomputed batch gather indices.
+        indices1: precomputed residue gather indices.
+
+    Returns:
+        Dict containing:
+        * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final
+          2 dimensions denote sin and cos respectively
+        * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
+          with the angle shifted by pi for all chi angles affected by the naming
+          ambiguities.
+        * 'torsion_angles_mask': Mask for which chi angles are present.
+    """
+
+    # Map aatype > 20 to 'Unknown' (20).
+    aatype = mnp.minimum(aatype, 20)
+
+    # Compute the backbone angles.
+    num_batch, num_res = aatype.shape
+
+    pad = mnp.zeros([num_batch, 1, 37, 3], mnp.float32)
+    prev_all_atom_pos = mnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1)
+
+    pad = mnp.zeros([num_batch, 1, 37], mnp.float32)
+    prev_all_atom_mask = mnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)
+
+    # For each torsion angle collect the 4 atom positions that define this angle.
+    # shape (B, N, atoms=4, xyz=3)
+    pre_omega_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 1:3, :], all_atom_pos[:, :, 0:2, :]], axis=-2)
+    phi_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 2:3, :], all_atom_pos[:, :, 0:3, :]], axis=-2)
+    psi_atom_pos = mnp.concatenate([all_atom_pos[:, :, 0:3, :], all_atom_pos[:, :, 4:5, :]], axis=-2)
+    # Collect the masks from these atoms. Shape [batch, num_res].
+    pre_omega_mask = (P.ReduceProd()(prev_all_atom_mask[:, :, 1:3], -1)  # prev CA, C
+                      * P.ReduceProd()(all_atom_mask[:, :, 0:2], -1))  # this N, CA
+    phi_mask = (prev_all_atom_mask[:, :, 2]  # prev C
+                * P.ReduceProd()(all_atom_mask[:, :, 0:3], -1))  # this N, CA, C
+    psi_mask = (P.ReduceProd()(all_atom_mask[:, :, 0:3], -1) *  # this N, CA, C
+                all_atom_mask[:, :, 4])  # this O
+    # Collect the atoms for the chi-angles.
+    # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
+    # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
+    atom_indices = mnp.take(chi_atom_indices, aatype, axis=0)
+
+    # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
+    # The gather indices are [4, seq_length, 4, 4]: batch, sequence length, chis, atoms.
+    seq_length = all_atom_pos.shape[1]
+    atom_indices = atom_indices.reshape((4, seq_length, 4, 4, 1)).astype("int32")
+    new_indices = P.Concat(4)((indices0, indices1, atom_indices))  # 4, seq_length, 4, 4, 3
+    chis_atom_pos = P.GatherNd()(all_atom_pos, new_indices)
+    chis_mask = mnp.take(chi_angles_mask, aatype, axis=0)
+    chi_angle_atoms_mask = P.GatherNd()(all_atom_mask, new_indices)
+
+    # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
+    chi_angle_atoms_mask = P.ReduceProd()(chi_angle_atoms_mask, -1)
+    chis_mask = chis_mask * chi_angle_atoms_mask.astype(mnp.float32)
+
+    # Stack all torsion angle atom positions.
+    # Shape (B, N, torsions=7, atoms=4, xyz=3)
+    torsions_atom_pos = mnp.concatenate([pre_omega_atom_pos[:, :, None, :, :],
+                                         phi_atom_pos[:, :, None, :, :],
+                                         psi_atom_pos[:, :, None, :, :],
+                                         chis_atom_pos], axis=2)
+    # Stack up masks for all torsion angles.
+    # shape (B, N, torsions=7)
+    torsion_angles_mask = mnp.concatenate([pre_omega_mask[:, :, None],
+                                           phi_mask[:, :, None],
+                                           psi_mask[:, :, None],
+                                           chis_mask], axis=2)
+
+    torsion_rigid = geometry.rigids_from_3_points(
+        geometry.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]),
+        geometry.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
+        geometry.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]))
+    inv_torsion_rigid = geometry.invert_rigids(torsion_rigid)
+    forth_atom_rel_pos = geometry.rigids_mul_vecs(inv_torsion_rigid,
+                                                  geometry.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))
+    # Compute the position of the fourth atom in this frame (y and z coordinates).
+    torsion_angles_sin_cos = mnp.stack([forth_atom_rel_pos[2], forth_atom_rel_pos[1]], axis=-1)
+    torsion_angles_sin_cos /= mnp.sqrt(mnp.sum(mnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8)
+    # Mirror psi, because we computed it from the Oxygen-atom.
+    torsion_angles_sin_cos *= mirror_psi_mask
+    chi_is_ambiguous = mnp.take(chi_pi_periodic, aatype, axis=0)
+    mirror_torsion_angles = mnp.concatenate([mnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous], axis=-1)
+    alt_torsion_angles_sin_cos = (torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])
+    return torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask
+
+
+def rigids_from_tensor4x4(m):
+    """Construct Rigids object from a 4x4 array.
+
+    Here the 4x4 is representing the transformation in homogeneous coordinates.
+
+    Args:
+        m: Array representing transformations in homogeneous coordinates.
+    Returns:
+        Rigids object corresponding to transformations m
+    """
+    rotation = (m[..., 0, 0], m[..., 0, 1], m[..., 0, 2],
+                m[..., 1, 0], m[..., 1, 1], m[..., 1, 2],
+                m[..., 2, 0], m[..., 2, 1], m[..., 2, 2])
+    trans = (m[..., 0, 3], m[..., 1, 3], m[..., 2, 3])
+    rigid = (rotation, trans)
+    return rigid
+
+
+def frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global, restype_atom14_to_rigid_group,
+                                                  restype_atom14_rigid_group_positions, restype_atom14_mask):  # (N, 14)
+    """Put atom literature positions (atom14 encoding) in each rigid group.
+
+    Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11
+
+    Args:
+        aatype: aatype for each residue.
+        all_frames_to_global: All per residue coordinate frames.
+    Returns:
+        Positions of all atom coordinates in global frame.
+    """
+
+    # Pick the appropriate transform for every atom.
+    residx_to_group_idx = P.Gather()(restype_atom14_to_rigid_group, aatype, 0)
+    group_mask = nn.OneHot(depth=8, axis=-1)(residx_to_group_idx)
+
+    # Rigids with shape (N, 14)
+    map_atoms_to_global = map_atoms_to_global_func(all_frames_to_global, group_mask)
+
+    # Gather the literature atom positions for each residue.
+    # Vecs with shape (N, 14)
+    lit_positions = geometry.vecs_from_tensor(P.Gather()(restype_atom14_rigid_group_positions, aatype, 0))
+
+    # Transform each atom from its local frame to the global frame.
+    # Vecs with shape (N, 14)
+    pred_positions = geometry.rigids_mul_vecs(map_atoms_to_global, lit_positions)
+
+    # Mask out non-existing atoms.
+ mask = P.Gather()(restype_atom14_mask, aatype, 0) + + pred_positions = geometry.vecs_scale(pred_positions, mask) + + return pred_positions + + +def rigids_concate_all(xall, x5, x6, x7): + """rigids concate all.""" + x5 = (geometry.rots_expand_dims(x5[0], -1), geometry.vecs_expand_dims(x5[1], -1)) + x6 = (geometry.rots_expand_dims(x6[0], -1), geometry.vecs_expand_dims(x6[1], -1)) + x7 = (geometry.rots_expand_dims(x7[0], -1), geometry.vecs_expand_dims(x7[1], -1)) + xall_rot = xall[0] + xall_rot_slice = [] + for val in xall_rot: + xall_rot_slice.append(val[:, 0:5]) + xall_trans = xall[1] + xall_trans_slice = [] + for val in xall_trans: + xall_trans_slice.append(val[:, 0:5]) + xall = (xall_rot_slice, xall_trans_slice) + res_rot = [] + for i in range(9): + res_rot.append(mnp.concatenate((xall[0][i], x5[0][i], x6[0][i], x7[0][i]), axis=-1)) + res_trans = [] + for i in range(3): + res_trans.append(mnp.concatenate((xall[1][i], x5[1][i], x6[1][i], x7[1][i]), axis=-1)) + return (res_rot, res_trans) + + +def torsion_angles_to_frames(aatype, backb_to_global, torsion_angles_sin_cos, restype_rigid_group_default_frame): + """Compute rigid group frames from torsion angles.""" + + # Gather the default frames for all rigid groups. + m = P.Gather()(restype_rigid_group_default_frame, aatype, 0) + + default_frames = rigids_from_tensor4x4(m) + + # Create the rotation matrices according to the given angles (each frame is + # defined such that its rotation is around the x-axis). + sin_angles = torsion_angles_sin_cos[..., 0] + cos_angles = torsion_angles_sin_cos[..., 1] + + # insert zero rotation for backbone group. + num_residues, = aatype.shape + sin_angles = mnp.concatenate([mnp.zeros([num_residues, 1]), sin_angles], axis=-1) + cos_angles = mnp.concatenate([mnp.ones([num_residues, 1]), cos_angles], axis=-1) + zeros = mnp.zeros_like(sin_angles) + ones = mnp.ones_like(sin_angles) + + all_rots = (ones, zeros, zeros, + zeros, cos_angles, -sin_angles, + zeros, sin_angles, cos_angles) + + # Apply rotations to the frames. + all_frames = geometry.rigids_mul_rots(default_frames, all_rots) + # chi2, chi3, and chi4 frames do not transform to the backbone frame but to + # the previous frame. So chain them up accordingly. 
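+    # In transform notation: chi2->backbone = (chi1->backbone) o (chi2->chi1),
+    # chi3->backbone = (chi2->backbone) o (chi3->chi2), and so on; each rigid
+    # below is a (9-element rotation, 3-element translation) tuple pair.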
+    chi2_frame_to_frame = ((all_frames[0][0][:, 5], all_frames[0][1][:, 5], all_frames[0][2][:, 5],
+                            all_frames[0][3][:, 5], all_frames[0][4][:, 5], all_frames[0][5][:, 5],
+                            all_frames[0][6][:, 5], all_frames[0][7][:, 5], all_frames[0][8][:, 5]),
+                           (all_frames[1][0][:, 5], all_frames[1][1][:, 5], all_frames[1][2][:, 5]))
+    chi3_frame_to_frame = ((all_frames[0][0][:, 6], all_frames[0][1][:, 6], all_frames[0][2][:, 6],
+                            all_frames[0][3][:, 6], all_frames[0][4][:, 6], all_frames[0][5][:, 6],
+                            all_frames[0][6][:, 6], all_frames[0][7][:, 6], all_frames[0][8][:, 6]),
+                           (all_frames[1][0][:, 6], all_frames[1][1][:, 6], all_frames[1][2][:, 6]))
+
+    chi4_frame_to_frame = ((all_frames[0][0][:, 7], all_frames[0][1][:, 7], all_frames[0][2][:, 7],
+                            all_frames[0][3][:, 7], all_frames[0][4][:, 7], all_frames[0][5][:, 7],
+                            all_frames[0][6][:, 7], all_frames[0][7][:, 7], all_frames[0][8][:, 7]),
+                           (all_frames[1][0][:, 7], all_frames[1][1][:, 7], all_frames[1][2][:, 7]))
+
+    chi1_frame_to_backb = ((all_frames[0][0][:, 4], all_frames[0][1][:, 4], all_frames[0][2][:, 4],
+                            all_frames[0][3][:, 4], all_frames[0][4][:, 4], all_frames[0][5][:, 4],
+                            all_frames[0][6][:, 4], all_frames[0][7][:, 4], all_frames[0][8][:, 4]),
+                           (all_frames[1][0][:, 4], all_frames[1][1][:, 4], all_frames[1][2][:, 4]))
+
+    chi2_frame_to_backb = geometry.rigids_mul_rigids(chi1_frame_to_backb, chi2_frame_to_frame)
+    chi3_frame_to_backb = geometry.rigids_mul_rigids(chi2_frame_to_backb, chi3_frame_to_frame)
+    chi4_frame_to_backb = geometry.rigids_mul_rigids(chi3_frame_to_backb, chi4_frame_to_frame)
+
+    # Recombine them to a Rigids with shape (N, 8).
+    all_frames_to_backb = rigids_concate_all(all_frames, chi2_frame_to_backb,
+                                             chi3_frame_to_backb, chi4_frame_to_backb)
+
+    backb_to_global = (geometry.rots_expand_dims(backb_to_global[0], -1),
+                       geometry.vecs_expand_dims(backb_to_global[1], -1))
+    # Create the global frames.
+    all_frames_to_global = geometry.rigids_mul_rigids(backb_to_global, all_frames_to_backb)
+    return all_frames_to_global
+
+
+def map_atoms_to_global_func(all_frames, group_mask):
+    """map atoms to the global frame."""
+    all_frames_rot = all_frames[0]
+    all_frames_trans = all_frames[1]
+    rot = geometry.rots_scale(geometry.rots_expand_dims(all_frames_rot, 1), group_mask)
+    res_rot = []
+    for val in rot:
+        res_rot.append(mnp.sum(val, axis=-1))
+    trans = geometry.vecs_scale(geometry.vecs_expand_dims(all_frames_trans, 1), group_mask)
+    res_trans = []
+    for val in trans:
+        res_trans.append(mnp.sum(val, axis=-1))
+    return (res_rot, res_trans)
+
+
+def atom14_to_atom37(atom14_data, residx_atom37_to_atom14, atom37_atom_exists, indices0):
+    """Convert atom14 to atom37 representation."""
+
+    seq_length = atom14_data.shape[0]
+    residx_atom37_to_atom14 = residx_atom37_to_atom14.reshape((seq_length, 37, 1))
+    new_indices = P.Concat(2)((indices0, residx_atom37_to_atom14))
+
+    atom37_data = P.GatherNd()(atom14_data, new_indices)
+
+    if len(atom14_data.shape) == 2:
+        atom37_data *= atom37_atom_exists
+    elif len(atom14_data.shape) == 3:
+        atom37_data *= atom37_atom_exists[:, :, None].astype(atom37_data.dtype)
+
+    return atom37_data
+
+
+def make_atom14_positions(aatype, all_atom_mask, all_atom_positions):
+    """
+    Transform the sparse (atom37) encoding of atom coordinates into the dense (atom14) encoding.
+
+    Atom coordinates in proteins are encoded in two ways.
+
+    - Sparse encoding: the 20 amino acids contain a total of 37 atom types as shown in
+      `common.residue_constants.atom_types`.
So coordinates of atoms in a protein can be encoded
+      as a Tensor with shape :math:`(N_{res}, 37, 3)`.
+    - Dense encoding: the 20 amino acids contain a total of 14 atom types as shown in
+      `common.residue_constants.restype_name_to_atom14_names`. So coordinates of atoms in a protein can be encoded
+      as a Tensor with shape :math:`(N_{res}, 14, 3)`.
+
+    Args:
+        aatype(numpy.array): Protein sequence encoding. The encoding method refers to
+            `common.residue_constants.restype_order`. The value ranges from 0 to 20.
+            20 means the amino acid is unknown (`UNK`).
+        all_atom_mask(numpy.array): Mask of coordinates of all atoms in the protein. Shape is
+            :math:`(N_{res}, 37)`. If the corresponding position is 0, the amino acid
+            does not contain the atom.
+        all_atom_positions(numpy.array): Coordinates of all atoms in the protein. Shape is :math:`(N_{res}, 37, 3)` .
+
+    Returns:
+        - numpy.array. Dense encoding, mask of all atoms in the protein, including unknown amino acid atoms.
+          Shape is :math:`(N_{res}, 14)`.
+        - numpy.array. Dense encoding, mask of all atoms in the protein, excluding unknown amino acid atoms.
+          Shape is :math:`(N_{res}, 14)`.
+        - numpy.array. Dense encoding, coordinates of all atoms in the protein. Shape is :math:`(N_{res}, 14, 3)`.
+        - numpy.array. Indices mapping the dense (atom14) encoding to sparse (atom37) atom indices.
+          Shape is :math:`(N_{res}, 14)` .
+        - numpy.array. Indices mapping the sparse (atom37) encoding to dense (atom14) atom indices.
+          Shape is :math:`(N_{res}, 37)` .
+        - numpy.array. Sparse encoding, mask of all atoms in the protein. Shape is :math:`(N_{res}, 37)` .
+        - numpy.array. Alternative dense-encoding coordinates after the chiral transformation, i.e. with
+          ambiguous atom names swapped. Shape is :math:`(N_{res}, 14, 3)` .
+        - numpy.array. Atom mask after the chiral transformation. Shape is :math:`(N_{res}, 14)` .
+        - numpy.array. Atom identifier of the chiral transformation. 1 means transformed and 0 means not
+          transformed. Shape is :math:`(N_{res}, 14)` .
+
+    Symbols:
+        - :math:`N_{res}` - The number of amino acids in a protein, i.e. the length of its sequence.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> from mindsponge.common import make_atom14_positions
+        >>> from mindsponge.common import protein
+        >>> import numpy as np
+        >>> pdb_path = "YOUR_PDB_FILE"
+        >>> with open(pdb_path, 'r', encoding='UTF-8') as f:
+        ...     prot_pdb = protein.from_pdb_string(f.read())
+        >>> result = make_atom14_positions(prot_pdb.aatype, prot_pdb.atom_mask.astype(np.float32),
+        ...                                prot_pdb.atom_positions.astype(np.float32))
+        >>> for val in result:
+        ...     print(val.shape)
+        (Nres, 14)
+        (Nres, 14)
+        (Nres, 14, 3)
+        (Nres, 14)
+        (Nres, 37)
+        (Nres, 37)
+        (Nres, 14, 3)
+        (Nres, 14)
+        (Nres, 14)
+    """
+    restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
+    restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
+    restype_atom14_mask = []
+
+    for rt in residue_constants.restypes:
+        atom_names = residue_constants.restype_name_to_atom14_names[
+            residue_constants.restype_1to3[rt]]
+
+        restype_atom14_to_atom37.append([
+            (residue_constants.atom_order[name] if name else 0)
+            for name in atom_names
+        ])
+
+        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
+        restype_atom37_to_atom14.append([
+            (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
+            for name in residue_constants.atom_types
+        ])
+
+        restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
+
+    # Add dummy mapping for restype 'UNK'.
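+    # The 21st row appended below is all zeros, so 'UNK' residues map every
+    # atom14 slot to atom37 index 0 with a zero mask.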
+    restype_atom14_to_atom37.append([0] * 14)
+    restype_atom37_to_atom14.append([0] * 37)
+    restype_atom14_mask.append([0.] * 14)
+
+    restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
+    restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
+    restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
+
+    # Create the mapping for (residx, atom14) --> atom37, i.e. an array
+    # with shape (num_res, 14) containing the atom37 indices for this protein.
+    residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
+    residx_atom14_mask = restype_atom14_mask[aatype]
+
+    # Create a mask for known ground truth positions.
+    residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
+        all_atom_mask, residx_atom14_to_atom37, axis=1).astype(np.float32)
+
+    # Gather the ground truth positions.
+    residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
+        np.take_along_axis(all_atom_positions, residx_atom14_to_atom37[..., None], axis=1))
+
+    atom14_atom_exists = residx_atom14_mask
+    atom14_gt_exists = residx_atom14_gt_mask
+    atom14_gt_positions = residx_atom14_gt_positions
+
+    # Create the gather indices for mapping back.
+    residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype]
+
+    # Create the corresponding mask.
+    restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+    for restype, restype_letter in enumerate(residue_constants.restypes):
+        restype_name = residue_constants.restype_1to3[restype_letter]
+        atom_names = residue_constants.residue_atoms[restype_name]
+        for atom_name in atom_names:
+            atom_type = residue_constants.atom_order[atom_name]
+            restype_atom37_mask[restype, atom_type] = 1
+
+    atom37_atom_exists = restype_atom37_mask[aatype]
+
+    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
+    # alternative ground truth coordinates where the naming is swapped.
+    restype_3 = [
+        residue_constants.restype_1to3[res] for res in residue_constants.restypes
+    ]
+    restype_3 += ["UNK"]
+
+    # Matrices for renaming ambiguous atoms.
+    all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
+    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
+        correspondences = np.arange(14)
+        for source_atom_swap, target_atom_swap in swap.items():
+            source_index = residue_constants.restype_name_to_atom14_names.get(resname).index(source_atom_swap)
+            target_index = residue_constants.restype_name_to_atom14_names.get(resname).index(target_atom_swap)
+            correspondences[source_index] = target_index
+            correspondences[target_index] = source_index
+        renaming_matrix = np.zeros((14, 14), dtype=np.float32)
+        for index, correspondence in enumerate(correspondences):
+            renaming_matrix[index, correspondence] = 1.
+        all_matrices[resname] = renaming_matrix.astype(np.float32)
+    renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
+
+    # Pick the transformation matrices for the given residue sequence
+    # shape (num_res, 14, 14).
+    renaming_transform = renaming_matrices[aatype]
+
+    # Apply it to the ground truth positions. shape (num_res, 14, 3).
+    alternative_gt_positions = np.einsum("rac,rab->rbc", residx_atom14_gt_positions, renaming_transform)
+    atom14_alt_gt_positions = alternative_gt_positions
+
+    # Create the mask for the alternative ground truth (differs from the
+    # ground truth mask, if only one of the atoms in an ambiguous pair has a
+    # ground truth position).
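+    # For a residue with one ambiguous pair (i, j), renaming_transform is the
+    # permutation matrix that swaps i and j; the einsum below applies this
+    # permutation to the per-atom mask.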
+    alternative_gt_mask = np.einsum("ra,rab->rb", residx_atom14_gt_mask, renaming_transform)
+
+    atom14_alt_gt_exists = alternative_gt_mask
+
+    # Create an ambiguous atoms mask. shape: (21, 14).
+    restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
+    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
+        for atom_name1, atom_name2 in swap.items():
+            restype = residue_constants.restype_order[
+                residue_constants.restype_3to1[resname]]
+            atom_idx1 = residue_constants.restype_name_to_atom14_names.get(resname).index(atom_name1)
+            atom_idx2 = residue_constants.restype_name_to_atom14_names.get(resname).index(atom_name2)
+            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
+            restype_atom14_is_ambiguous[restype, atom_idx2] = 1
+
+    # From this create an ambiguous_mask for the given sequence.
+    atom14_atom_is_ambiguous = restype_atom14_is_ambiguous[aatype]
+    return_pack = (atom14_atom_exists, atom14_gt_exists, atom14_gt_positions, residx_atom14_to_atom37,
+                   residx_atom37_to_atom14, atom37_atom_exists, atom14_alt_gt_positions, atom14_alt_gt_exists,
+                   atom14_atom_is_ambiguous)
+    return return_pack
+
+
+def get_pdb_info(pdb_path):
+    """
+    get atom positions, residue index and related features from a pdb file.
+
+    Args:
+        pdb_path(str): the path of the input pdb.
+
+    Returns:
+        features(dict), the information of the pdb, including the following keys
+
+        - aatype, numpy.array. Protein sequence encoding. The encoding method refers to
+          `common.residue_constants.restype_order`, with values in [0, 20]. 20 means the amino acid is `UNK`.
+          Shape :math:`(N_{res})` .
+        - all_atom_positions, numpy.array. Coordinates of all atoms in the pdb. Shape :math:`(N_{res}, 37, 3)` .
+        - all_atom_mask, numpy.array. Mask of atoms in the pdb. Shape :math:`(N_{res}, 37)` .
+          0 means the atom does not exist.
+        - atom14_atom_exists, numpy.array. Dense encoding, mask of all atoms in the protein.
+          The position with atoms is 1 and the position without atoms is 0. Shape is :math:`(N_{res}, 14)`.
+        - atom14_gt_exists, numpy.array. Dense encoding, mask of atoms whose ground-truth coordinates
+          exist in the pdb. Shape is :math:`(N_{res}, 14)`.
+        - atom14_gt_positions, numpy.array. Dense encoding, coordinates of all atoms in the protein.
+          Shape is :math:`(N_{res}, 14, 3)`.
+        - residx_atom14_to_atom37, numpy.array. Indices mapping the dense (atom14) encoding to sparse
+          (atom37) atom indices. Shape is :math:`(N_{res}, 14)` .
+        - residx_atom37_to_atom14, numpy.array. Indices mapping the sparse (atom37) encoding to dense
+          (atom14) atom indices. Shape is :math:`(N_{res}, 37)` .
+        - atom37_atom_exists, numpy.array. Sparse encoding, mask of all atoms in the protein.
+          The position with atoms is 1 and the position without atoms is 0. Shape is :math:`(N_{res}, 37)`.
+        - atom14_alt_gt_positions, numpy.array. Dense encoding, alternative ground-truth coordinates with
+          ambiguous atom names swapped. Shape is :math:`(N_{res}, 14, 3)` .
+        - atom14_alt_gt_exists, numpy.array. Dense encoding, mask of the alternative ground-truth atoms.
+          Shape is :math:`(N_{res}, 14)` .
+        - atom14_atom_is_ambiguous, numpy.array. Due to the local symmetry of some amino acid structures,
+          the names of symmetric atoms can be swapped. The affected atoms are listed in
+          `common.residue_constants.residue_atom_renaming_swaps`. This feature marks the ambiguous atom
+          encoding positions. Shape is :math:`(N_{res}, 14)` .
+        - residue_index, numpy.array. Residue index information of the protein sequence, ranging
+          from 1 to :math:`N_{res}` .
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> from mindsponge.common import get_pdb_info
+        >>> pdb_path = "YOUR PDB PATH"
+        >>> pdb_feature = get_pdb_info(pdb_path)
+        >>> for feature in pdb_feature:
+        ...     print(feature, pdb_feature[feature].shape)
+        # Nres is the number of amino acids in the input pdb.
+        aatype (Nres,)
+        all_atom_positions (Nres, 37, 3)
+        all_atom_mask (Nres, 37)
+        atom14_atom_exists (Nres, 14)
+        atom14_gt_exists (Nres, 14)
+        atom14_gt_positions (Nres, 14, 3)
+        residx_atom14_to_atom37 (Nres, 14)
+        residx_atom37_to_atom14 (Nres, 37)
+        atom37_atom_exists (Nres, 37)
+        atom14_alt_gt_positions (Nres, 14, 3)
+        atom14_alt_gt_exists (Nres, 14)
+        atom14_atom_is_ambiguous (Nres, 14)
+        residue_index (Nres,)
+
+    """
+    with open(pdb_path, 'r', encoding="UTF-8") as f:
+        prot_pdb = protein.from_pdb_string(f.read())
+    aatype = prot_pdb.aatype
+    atom37_positions = prot_pdb.atom_positions.astype(np.float32)
+    atom37_mask = prot_pdb.atom_mask.astype(np.float32)
+
+    # get ground truth of atom14
+    features = {'aatype': aatype,
+                'all_atom_positions': atom37_positions,
+                'all_atom_mask': atom37_mask}
+    atom14_atom_exists, atom14_gt_exists, atom14_gt_positions, residx_atom14_to_atom37, residx_atom37_to_atom14, \
+        atom37_atom_exists, atom14_alt_gt_positions, atom14_alt_gt_exists, atom14_atom_is_ambiguous = \
+        make_atom14_positions(aatype, atom37_mask, atom37_positions)
+    features.update({"atom14_atom_exists": atom14_atom_exists,
+                     "atom14_gt_exists": atom14_gt_exists,
+                     "atom14_gt_positions": atom14_gt_positions,
+                     "residx_atom14_to_atom37": residx_atom14_to_atom37,
+                     "residx_atom37_to_atom14": residx_atom37_to_atom14,
+                     "atom37_atom_exists": atom37_atom_exists,
+                     "atom14_alt_gt_positions": atom14_alt_gt_positions,
+                     "atom14_alt_gt_exists": atom14_alt_gt_exists,
+                     "atom14_atom_is_ambiguous": atom14_atom_is_ambiguous})
+
+    features["residue_index"] = prot_pdb.residue_index
+
+    return features
+
+
+def get_fasta_info(pdb_path):
+    """
+    Read a pdb file and return its fasta sequence.
+
+    Args:
+        pdb_path(str): path of the input pdb.
+
+    Returns:
+        fasta(str), fasta sequence of the input pdb. The sequence follows the order of residues in the
+        protein and has no relationship with the residue index, such as "GSHMGVQ".
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> from mindsponge.common import get_fasta_info
+        >>> pdb_path = "YOUR PDB PATH"
+        >>> fasta = get_fasta_info(pdb_path)
+        >>> print(fasta)
+        GSHMGVQ
+
+    """
+    with open(pdb_path, 'r', encoding='UTF-8') as f:
+        prot_pdb = protein.from_pdb_string(f.read())
+    aatype = prot_pdb.aatype
+    fasta = [residue_constants.order_restype_with_x.get(x, "X") for x in aatype]
+
+    return ''.join(fasta)
+
+
+def get_aligned_seq(gt_seq, pr_seq):
+    """
+    Align two protein fasta sequences. Return the two aligned sequences and the positions of
+    identical residues.
+
+    Args:
+        gt_seq(str): one protein fasta sequence, such as "ABAAABAA".
+        pr_seq(str): another protein fasta sequence, such as "AAABBBA".
+
+    Returns:
+        - target(str), one protein fasta sequence.
+        - align_relationship(str), the differences of the two sequences.
+        - query(str), another protein fasta sequence.
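+
+        In `align_relationship`, "|" marks a match, "." a mismatch and "-" a gap,
+        as shown in the example below.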
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> from mindsponge.common import get_aligned_seq
+        >>> gt_seq = "ABAAABAA"
+        >>> pr_seq = "AAABBBA"
+        >>> aligned_gt_seq, aligned_info, aligned_pr_seq = get_aligned_seq(gt_seq, pr_seq)
+        >>> print(aligned_gt_seq)
+        ABAAABAA
+        >>> print(aligned_info)
+        |-||.|.|
+        >>> print(aligned_pr_seq)
+        A-AABBBA
+
+    """
+    aligner = Align.PairwiseAligner()
+    matrix = substitution_matrices.load("BLOSUM62")
+    for i in range(len(str(matrix.alphabet))):
+        res = matrix.alphabet[i]
+        matrix['X'][res] = 0
+        matrix[res]['X'] = 0
+    aligner.substitution_matrix = matrix
+    aligner.open_gap_score = -10
+    aligner.extend_gap_score = -1
+    # Among all alignment results, keep only the one with the highest score, with gt_seq as the reference.
+    alignments = aligner.align(gt_seq, pr_seq)
+    align = alignments[0]
+    align_str = str(align)
+    align_str_len = len(align_str)
+    point = []
+    target = ''
+    align_relationship = ''
+    query = ''
+    for i in range(align_str_len):
+        if align_str[i] == '\n':
+            point.append(i)
+    for i in range(int(point[0])):
+        target = target + align_str[i]
+    for i in range(int(point[1])-int(point[0])-1):
+        align_relationship = align_relationship + align_str[i + int(point[0])+1]
+    for i in range(int(point[2])-int(point[1])-1):
+        query = query + align_str[i + int(point[1])+1]
+    return target, align_relationship, query
+
+
+def find_optimal_renaming(
+        atom14_gt_positions,
+        atom14_alt_gt_positions,
+        atom14_atom_is_ambiguous,
+        atom14_gt_exists,
+        atom14_pred_positions,
+):  # (N)
+    """
+    Find optimal renaming for ground truth that maximizes LDDT.
+
+    Reference:
+        `Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms"
+        <https://www.nature.com/articles/s41586-021-03819-2>`_.
+
+    Args:
+        atom14_gt_positions (Tensor): Ground truth positions in global frame with shape :math:`(N_{res}, 14, 3)`.
+        atom14_alt_gt_positions (Tensor): Alternate ground truth positions in global frame with coordinates of
+            ambiguous atoms swapped relative to 'atom14_gt_positions'.
+            The shape is :math:`(N_{res}, 14, 3)`.
+        atom14_atom_is_ambiguous (Tensor): Mask denoting whether atom is among ambiguous atoms,
+            see Jumper et al. (2021) Suppl. Table 3. The shape is :math:`(N_{res}, 14)`.
+        atom14_gt_exists (Tensor): Mask denoting whether atom at positions exists in ground truth with
+            shape :math:`(N_{res}, 14)`.
+        atom14_pred_positions (Tensor): Predicted positions of atoms in global prediction frame with
+            shape :math:`(N_{res}, 14, 3)`.
+
+    Returns:
+        Tensor, :math:`(N_{res},)` with 1.0 where atom14_alt_gt_positions is closer to prediction and otherwise 0.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.common.utils import find_optimal_renaming
+        >>> from mindspore import Tensor
+        >>> n_res = 16
+        >>> atom14_gt_positions = Tensor(np.random.randn(n_res, 14, 3).astype(np.float32))
+        >>> atom14_alt_gt_positions = Tensor(np.random.randn(n_res, 14, 3).astype(np.float32))
+        >>> atom14_atom_is_ambiguous = Tensor(np.random.randn(n_res, 14).astype(np.float32))
+        >>> atom14_gt_exists = Tensor(np.random.randn(n_res, 14).astype(np.float32))
+        >>> atom14_pred_positions = Tensor(np.random.randn(n_res, 14, 3).astype(np.float32))
+        >>> out = find_optimal_renaming(atom14_gt_positions, atom14_alt_gt_positions,
+        ...                             atom14_atom_is_ambiguous, atom14_gt_exists, atom14_pred_positions)
+        >>> print(out.shape)
+        (16,)
+    """
+
+    # Create the pred distance matrix.
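+    # Broadcasting (N, 1, 14, 1, :) against (1, N, 1, 14, :) and summing over
+    # the coordinate axis yields pairwise distance tensors of shape (N, N, 14, 14).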
+    atom14_pred_positions = P.Pad(((0, 0), (0, 0), (0, 5)))(atom14_pred_positions)
+    pred_dists = mnp.sqrt(1e-10 + mnp.sum(
+        mnp.square(atom14_pred_positions[:, None, :, None, :] - atom14_pred_positions[None, :, None, :, :]), axis=-1))
+
+    # Compute distances for ground truth with original and alternative names.
+    gt_dists = mnp.sqrt(1e-10 + mnp.sum(
+        mnp.square(atom14_gt_positions[:, None, :, None, :] - atom14_gt_positions[None, :, None, :, :]), axis=-1))
+    alt_gt_dists = mnp.sqrt(1e-10 + mnp.sum(
+        mnp.square(atom14_alt_gt_positions[:, None, :, None, :] - atom14_alt_gt_positions[None, :, None, :, :]),
+        axis=-1))
+
+    # Compute LDDT's.
+    lddt = mnp.sqrt(1e-10 + mnp.square(pred_dists - gt_dists))
+    alt_lddt = mnp.sqrt(1e-10 + mnp.square(pred_dists - alt_gt_dists))
+
+    # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms
+    # in cols.
+    mask = (atom14_gt_exists[:, None, :, None] *  # rows
+            atom14_atom_is_ambiguous[:, None, :, None] *  # rows
+            atom14_gt_exists[None, :, None, :] *  # cols
+            (1. - atom14_atom_is_ambiguous[None, :, None, :]))  # cols
+
+    # Aggregate distances for each residue to the non-ambiguous atoms.
+    per_res_lddt = P.ReduceSum()(mask * lddt, (1, 2, 3))
+    alt_per_res_lddt = P.ReduceSum()(mask * alt_lddt, (1, 2, 3))
+
+    # Decide for each residue, whether alternative naming is better.
+    alt_naming_is_better = (alt_per_res_lddt < per_res_lddt)
+
+    return alt_naming_is_better
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61c20fc5a70bac8390d74325edf75b8d54391ad9
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/__init__.py
@@ -0,0 +1,33 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+#                       Peking University &
+#                       Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""Controller""" + +from .controller import Controller +from .integrator import Integrator, LeapFrog, VelocityVerlet, Brownian +from .thermostat import Thermostat, BerendsenThermostat, Langevin +from .barostat import Barostat, BerendsenBarostat +from .constraint import Constraint, Lincs + +__all__ = ['Controller', 'Integrator', 'LeapFrog', 'VelocityVerlet', 'Brownian', + 'Thermostat', 'BerendsenThermostat', 'Langevin', 'Barostat', + 'BerendsenBarostat', 'Constraint', 'Lincs'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/barostat/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/barostat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..638a6b1e47694af2d104ecd7a82c8a77873dd3d5 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/barostat/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Barostat""" + +from .barostat import Barostat +from .berendsen import BerendsenBarostat + +__all__ = ['Barostat', 'BerendsenBarostat'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/barostat/barostat.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/barostat/barostat.py new file mode 100644 index 0000000000000000000000000000000000000000..6012d584dcda841f2716aef60e08368c453028ed --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/barostat/barostat.py @@ -0,0 +1,177 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""
+Barostat
+"""
+
+import mindspore as ms
+import mindspore.numpy as msnp
+from mindspore import Tensor, Parameter
+from mindspore.ops import functional as F
+
+from .. import Controller
+from ...system import Molecule
+
+
+class Barostat(Controller):
+    r"""
+    Barostat controller for pressure coupling.
+
+    Args:
+        system (Molecule): Simulation system.
+        pressure (float): Reference pressure P_ref (bar) for pressure coupling. Default: 1
+        anisotropic (bool): Whether to perform anisotropic pressure control. Default: False
+        control_step (int): Step interval for controller execution. Default: 1
+        compressibility (float): Isothermal compressibility \beta (bar^-1). Default: 4.6e-5
+        time_constant (float): Time constant \tau_p (ps) for pressure coupling. Default: 1
+
+    Returns:
+        coordinate (Tensor), Tensor of shape (B, A, D). Data type is float.
+        velocity (Tensor), Tensor of shape (B, A, D). Data type is float.
+        force (Tensor), Tensor of shape (B, A, D). Data type is float.
+        energy (Tensor), Tensor of shape (B, 1). Data type is float.
+        kinetics (Tensor), Tensor of shape (B, D). Data type is float.
+        virial (Tensor), Tensor of shape (B, D). Data type is float.
+        pbc_box (Tensor), Tensor of shape (B, D). Data type is float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def __init__(self,
+                 system: Molecule,
+                 pressure: float = 1,
+                 anisotropic: bool = False,
+                 control_step: int = 1,
+                 compressibility: float = 4.6e-5,
+                 time_constant: float = 1.,
+                 ):
+
+        super().__init__(
+            system=system,
+            control_step=control_step,
+        )
+
+        self.anisotropic = anisotropic
+        self.kinetic_unit_scale = self.units.kinetic_ref
+        self.press_unit_scale = self.units.pressure_ref
+
+        self.sens = Tensor(1e8, ms.float32)
+        self.inv_sens = msnp.reciprocal(self.sens)
+
+        # (B,1)
+        self.ref_press = Tensor(pressure, ms.float32).reshape(-1, 1)
+        if self.ref_press.shape[0] != 1 and self.ref_press.shape[0] != self.num_walker:
+            raise ValueError('The first dimension of "pressure" (' + str(self.ref_press.shape[0]) +
+                             ') does not match the number of multiple walkers (' + str(self.num_walker) + ')!')
+
+        # isothermal compressibility
+        self.beta = Tensor(compressibility, ms.float32)
+
+        # \tau_p
+        self.time_constant = Tensor(time_constant, ms.float32).reshape(-1, 1)
+        if self.time_constant.shape[0] != self.num_walker and self.time_constant.shape[0] != 1:
+            raise ValueError('The first dimension of "time_constant" must be equal to 1 or num_walker')
+
+        self.shape = (self.num_walker, self.dimension)
+        self.change_accumulation = Parameter(msnp.zeros(self.shape), name='change_accumulation', requires_grad=False)
+
+        self.critical_change = 1e-6
+
+    @property
+    def pressure(self):
+        """reference pressure."""
+        return self.ref_press
+
+    @property
+    def compressibility(self):
+        """isothermal compressibility."""
+        return self.beta
+
+    def pressure_scale(self, sim_press: Tensor, ref_press: Tensor, ratio: float = 1) -> Tensor:
+        """
+        calculate the coordinate scale factor for pressure coupling.
+
+        Args:
+            sim_press (Tensor): The tensor of simulation pressure.
+            ref_press (Tensor): The tensor of reference pressure.
+            ratio (float): The ratio used to scale the difference of the two pressures. Default: 1
+        """
+        delta_p = ref_press - sim_press
+        change = - ratio * self.beta * delta_p
+
+        # If the change is too small, the float32 data will not be able to represent the scale.
+ # Therefore, the small changes will be accumulated: + # (1 + x) ^ n \approx 1 + nx, when x << 1 + # When the total change accumulates to a critical value, then the coordinate and PBC box will be scaled. + change += self.change_accumulation + mask = msnp.abs(change) > self.critical_change + scale = msnp.where(mask, 1+change, 1.) + change = msnp.where(mask, 0., change) + F.depend(True, F.assign(self.change_accumulation, change)) + + return scale + + def construct(self, + coordinate: Tensor, + velocity: Tensor, + force: Tensor, + energy: Tensor, + kinetics: Tensor, + virial: Tensor = None, + pbc_box: Tensor = None, + step: int = 0, + ): + + r""" + Control the pressure of the simulation system. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + velocity (Tensor): Tensor of shape (B, A, D). Data type is float. + force (Tensor): Tensor of shape (B, A, D). Data type is float. + energy (Tensor): Tensor of shape (B, 1). Data type is float. + kinetics (Tensor): Tensor of shape (B, D). Data type is float. + virial (Tensor): Tensor of shape (B, D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + step (int): Simulation step. Default: 0 + + Returns: + coordinate (Tensor), Tensor of shape (B, A, D). Data type is float. + velocity (Tensor), Tensor of shape (B, A, D). Data type is float. + force (Tensor), Tensor of shape (B, A, D). Data type is float. + energy (Tensor), Tensor of shape (B, 1). Data type is float. + kinetics (Tensor), Tensor of shape (B, D). Data type is float. + virial (Tensor), Tensor of shape (B, D). Data type is float. + pbc_box (Tensor), Tensor of shape (B, D). Data type is float. + + Symbols: + B: Number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + + """ + + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/barostat/berendsen.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/barostat/berendsen.py new file mode 100644 index 0000000000000000000000000000000000000000..490e49293a6dd1c1cc790f4b1f45dca4ce7321ad --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/barostat/berendsen.py @@ -0,0 +1,124 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Berendsen barostat""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor +from mindspore.ops import functional as F + +from . import Barostat +from ...system import Molecule + + +class BerendsenBarostat(Barostat): + r""" + A Berendsen (weak coupling) barostat controller. 
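+
+    The coordinates and PBC box are scaled towards the reference pressure by a
+    factor that matches, to first order in the pressure deviation, the Berendsen
+    weak-coupling factor :math:`\mu = [1 - \beta \Delta t / \tau_p (P_{ref} - P)]^{1/3}`.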
+
+    Reference:
+        `Berendsen, H. J. C.; Postma, J. P. M.; van Gunsteren, W. F.; DiNola, A.; Haak, J. R..
+        Molecular Dynamics with Coupling to an External Bath [J].
+        The Journal of Chemical Physics, 1984, 81(8): 3684.
+        <https://doi.org/10.1063/1.448118>`_.
+
+    Args:
+        system (Molecule): Simulation system.
+        pressure (float): Reference pressure P_ref (bar) for pressure coupling. Default: 1
+        anisotropic (bool): Whether to perform anisotropic pressure control. Default: False
+        control_step (int): Step interval for controller execution. Default: 1
+        compressibility (float): Isothermal compressibility \beta (bar^-1). Default: 4.6e-5
+        time_constant (float): Time constant \tau_p (ps) for pressure coupling. Default: 1
+
+    Returns:
+        coordinate (Tensor), Tensor of shape (B, A, D). Data type is float.
+        velocity (Tensor), Tensor of shape (B, A, D). Data type is float.
+        force (Tensor), Tensor of shape (B, A, D). Data type is float.
+        energy (Tensor), Tensor of shape (B, 1). Data type is float.
+        kinetics (Tensor), Tensor of shape (B, D). Data type is float.
+        virial (Tensor), Tensor of shape (B, D). Data type is float.
+        pbc_box (Tensor), Tensor of shape (B, D). Data type is float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def __init__(self,
+                 system: Molecule,
+                 pressure: float = 1,
+                 anisotropic: bool = False,
+                 control_step: int = 1,
+                 compressibility: float = 4.6e-5,
+                 time_constant: float = 1.,
+                 ):
+
+        super().__init__(
+            system=system,
+            pressure=pressure,
+            anisotropic=anisotropic,
+            control_step=control_step,
+            compressibility=compressibility,
+            time_constant=time_constant,
+        )
+
+        self.ratio = self.control_step * self.time_step / self.time_constant / 3.
+
+    def set_time_step(self, dt: float):
+        """
+        set simulation time step.
+
+        Args:
+            dt (float): Time of a time step.
+        """
+        self.time_step = Tensor(dt, ms.float32)
+        self.ratio = self.control_step * self.time_step / self.time_constant / 3.
+        return self
+
+    def construct(self,
+                  coordinate: Tensor,
+                  velocity: Tensor,
+                  force: Tensor,
+                  energy: Tensor,
+                  kinetics: Tensor,
+                  virial: Tensor = None,
+                  pbc_box: Tensor = None,
+                  step: int = 0,
+                  ):
+
+        if self.control_step == 1 or step % self.control_step == 0:
+            pressure = self.get_pressure(kinetics, virial, pbc_box)
+            if not self.anisotropic:
+                # (B,1) <- (B,D):
+                pressure = msnp.mean(pressure, axis=-1, keepdims=True)
+                # (B,D) <- (B,1):
+                pressure = msnp.broadcast_to(pressure, self.shape)
+            # (B,D):
+            scale = self.pressure_scale(pressure, self.ref_press, self.ratio)
+
+            # (B,A,D) * (B,1,D):
+            coordinate *= F.expand_dims(scale, -2)
+            # (B,D):
+            pbc_box *= scale
+
+        return coordinate, velocity, force, energy, kinetics, virial, pbc_box
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/constraint/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/constraint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab17ec9904766bc82f4ab4d20d34c07c81a34e55
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/constraint/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+#                       Peking University &
+#                       Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""constraint"""
+
+from .constraint import Constraint
+from .lincs import Lincs
+
+__all__ = ['Constraint', 'Lincs']
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/constraint/constraint.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/constraint/constraint.py
new file mode 100644
index 0000000000000000000000000000000000000000..9baa43082bc8ec71ebc2b15ddefd57b79d361f1d
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/constraint/constraint.py
@@ -0,0 +1,154 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+#                       Peking University &
+#                       Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Constraint
+"""
+
+import numpy as np
+import mindspore as ms
+from mindspore import Tensor, Parameter
+
+from .. import Controller
+from ...system import Molecule
+from ...potential import PotentialCell
+from ...function.operations import GetVector, GetDistance
+
+
+class Constraint(Controller):
+    r"""
+    Constraint for bonds.
+
+    Args:
+        system (Molecule): Simulation system.
+        bonds (Tensor or str): Bonds to be constrained.
+            Tensor of shape (K, 2). Data type is int.
+            Alternative: "h-bonds" or "all-bonds".
+        potential (PotentialCell): Potential Cell. Default: None
+
+    Returns:
+        - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - velocity (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - force (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - energy (Tensor), Tensor of shape (B, 1). Data type is float.
+        - kinetics (Tensor), Tensor of shape (B, D). Data type is float.
+        - virial (Tensor), Tensor of shape (B, D). Data type is float.
+        - pbc_box (Tensor), Tensor of shape (B, D). Data type is float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self,
+                 system: Molecule,
+                 bonds: Tensor = 'h-bonds',
+                 potential: PotentialCell = None,
+                 ):
+
+        super().__init__(
+            system=system,
+            control_step=1,
+        )
+
+        if potential is None:
+            self.all_bonds = system.bond
+            self.h_bonds = system.hydrogen_bond
+        else:
+            self.all_bonds = potential.bond
+            self.h_bonds = potential.hydrogen_bond
+
+        if isinstance(bonds, (Tensor, Parameter, np.ndarray)):
+            self.bonds = Tensor(bonds, ms.int32)
+        elif isinstance(bonds, str):
+            if bonds.lower() == 'h-bonds':
+                self.bonds = self.h_bonds
+            elif bonds.lower() == 'all-bonds':
+                self.bonds = self.all_bonds
+            else:
+                raise ValueError(
+                    '"bonds" must be "h-bonds" or "all-bonds" but got: ' + bonds)
+        else:
+            raise TypeError(
+                'The type of "bonds" must be Tensor or str, but got: ' + str(type(bonds)))
+
+        if self.bonds.ndim != 2:
+            if self.bonds.ndim != 3:
+                raise ValueError(
+                    'The rank of "bonds" must be 2 or 3 but got: ' + str(self.bonds.ndim))
+
+            if self.bonds.shape[0] != 1:
+                raise ValueError('For constraint, the batch size of "bonds" must be 1 but got: ' +
+                                 str(self.bonds.shape[0]))
+            self.bonds = self.bonds[0]
+
+        if self.bonds.shape[-1] != 2:
+            raise ValueError(
+                'The last dimension of "bonds" must be 2 but got: ' + str(self.bonds.shape[-1]))
+
+        # C
+        self.num_constraints = self.bonds.shape[-2]
+
+        self.use_pbc = self._pbc_box is not None
+
+        self.get_vector = GetVector(self.use_pbc)
+        self.get_distance = GetDistance(use_pbc=self.use_pbc)
+
+    def construct(self,
+                  coordinate: Tensor,
+                  velocity: Tensor,
+                  force: Tensor,
+                  energy: Tensor,
+                  kinetics: Tensor,
+                  virial: Tensor = None,
+                  pbc_box: Tensor = None,
+                  step: int = 0,
+                  ):
+        """
+        constrain the bonds.
+
+        Args:
+            coordinate (Tensor): Tensor of shape (B, A, D). Data type is float.
+            velocity (Tensor): Tensor of shape (B, A, D). Data type is float.
+            force (Tensor): Tensor of shape (B, A, D). Data type is float.
+            energy (Tensor): Tensor of shape (B, 1). Data type is float.
+            kinetics (Tensor): Tensor of shape (B, D). Data type is float.
+            virial (Tensor): Tensor of shape (B, D). Data type is float.
+            pbc_box (Tensor): Tensor of shape (B, D). Data type is float.
+            step (int): Simulation step. Default: 0
+
+        Returns:
+            coordinate (Tensor), Tensor of shape (B, A, D). Data type is float.
+            velocity (Tensor), Tensor of shape (B, A, D). Data type is float.
+            force (Tensor), Tensor of shape (B, A, D). Data type is float.
+            energy (Tensor), Tensor of shape (B, 1). Data type is float.
+            kinetics (Tensor), Tensor of shape (B, D). Data type is float.
+            virial (Tensor), Tensor of shape (B, D). Data type is float.
+            pbc_box (Tensor), Tensor of shape (B, D). Data type is float.
+
+        Symbols:
+            B: Number of walkers in simulation.
+            A: Number of atoms.
+            D: Dimension of the simulation system. Usually is 3.
+
+        """
+
+        raise NotImplementedError
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/constraint/lincs.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/constraint/lincs.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce47a9601eb3d50ffaf6f932beb7d4b3bd1ec9c
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/constraint/lincs.py
@@ -0,0 +1,205 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+#                       Peking University &
+#                       Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+LINCS Constraint algorithm
+"""
+
+import numpy as np
+import mindspore as ms
+import mindspore.numpy as msnp
+from mindspore import Tensor
+from mindspore import ops
+from mindspore.ops import functional as F
+
+from . import Constraint
+from ...system import Molecule
+from ...potential import PotentialCell
+from ...function.operations import GetShiftGrad
+
+
+class Lincs(Constraint):
+    """
+    LINCS (LINear Constraint Solver) constraint controller.
+
+    Args:
+        system (Molecule): Simulation system.
+        bonds (Tensor): Bonds to be constrained.
+            Tensor of shape (K, 2). Data type is int.
+            Default: "h-bonds".
+        potential (PotentialCell): Potential Cell. Default: None
+
+    Inputs:
+        - **coordinate** (Tensor) - The coordinates of the system.
+        - **velocity** (Tensor) - The velocity of the system.
+        - **force** (Tensor) - The force of the system.
+        - **energy** (Tensor) - The energy of the system.
+        - **kinetics** (Tensor) - The kinetics of the system.
+        - **virial** (Tensor) - The virial of the system. Default: None
+        - **pbc_box** (Tensor) - PBC box of the system. Default: None
+        - **step** (int) - The step of the system. Default: 0
+
+    Returns:
+        - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - velocity (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - force (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - energy (Tensor), Tensor of shape (B, 1). Data type is float.
+        - kinetics (Tensor), Tensor of shape (B, D). Data type is float.
+        - virial (Tensor), Tensor of shape (B, D). Data type is float.
+        - pbc_box (Tensor), Tensor of shape (B, D). Data type is float.
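+
+    Reference:
+        Hess, B.; Bekker, H.; Berendsen, H. J. C.; Fraaije, J. G. E. M..
+        LINCS: A Linear Constraint Solver for Molecular Simulations [J].
+        Journal of Computational Chemistry, 1997, 18(12): 1463-1472.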
+ + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + system: Molecule, + bonds: Tensor = 'h-bonds', + potential: PotentialCell = None, + ): + + super().__init__( + system=system, + bonds=bonds, + potential=potential, + ) + #pylint: disable=invalid-name + + # (A,A) <- (A,A) + iinvM = msnp.identity(self.num_atoms) + + # (B,A,A) = (1,A,A) * (B,1,A) + self.Mii = msnp.broadcast_to( + iinvM, (1,) + iinvM.shape) * self.inv_mass[:, None, :] + + self.BMatrix = GetShiftGrad( + num_atoms=self.num_atoms, + bonds=self.bonds, + num_walkers=self.num_walker, + dimension=self.dimension, + use_pbc=self.use_pbc + ) + # (B,C,A,D) + shape = (self.num_walker, + self.bonds.shape[-2], self.num_atoms, self.dimension) + + self.broadcast = ops.BroadcastTo(shape) + self.inv = ops.MatrixInverse(adjoint=False) + self.squeeze = ops.Squeeze() + self.einsum0 = ops.Einsum('ijk,ilkm->iljm') + self.einsum1 = ops.Einsum('ijkl,imkl->ijm') + self.einsum2 = ops.Einsum('ijkl,ikl->ij') + self.einsum3 = ops.Einsum('ijk,ik->ij') + self.einsum4 = ops.Einsum('ijkl,ij->ikl') + self.einsum5 = ops.Einsum('ijk,ikl->ijl') + + # (B,C,A) + shape = (self.num_walker, self.num_constraints, self.num_atoms) + + # (1,C,1) + bond0 = self.bonds[..., 0].reshape(1, -1, 1).asnumpy() + # (B,C,A) <- (B,A,1) + mask0 = np.zeros(shape) + np.put_along_axis(mask0, bond0, 1, axis=-1) + # (B,C,A,1) + self.mask0 = F.expand_dims(Tensor(mask0, ms.int32), -1) + + # (1,C,1) + bond1 = self.bonds[..., 1].reshape(1, -1, 1).asnumpy() + # (B,C,A) <- (B,A,1) + mask1 = np.zeros(shape) + np.put_along_axis(mask1, bond1, 1, axis=-1) + # (B,C,A,1) + self.mask1 = F.expand_dims(Tensor(mask1, ms.int32), -1) + + def construct(self, + coordinate: Tensor, + velocity: Tensor, + force: Tensor, + energy: Tensor, + kinetics: Tensor, + virial: Tensor = None, + pbc_box: Tensor = None, + step: int = 0, + ): + """ Construct function of Lincs""" + #pylint: disable=invalid-name + + # (B,A,D) + coordinate_old = self._coordinate + coordinate_new = coordinate + + # (B,C,A,D) + BMatrix = self.BMatrix(coordinate_new, coordinate_old, pbc_box) + + # ijk,ilkm->iljm + # (B,A,A),(B,C,A,D)->(B,C,A,D) + # (B,1,A,A,1),(B,C,1,A,D)->(B,C,A,'A',D)->(B,C,A,D) + tmp0 = self.einsum0((self.Mii, BMatrix)) + + # ijkl,imkl->ijm + # (B,C,A,D),(B,C,A,D)->(B,C,C) + # (B,C,A,D),(B,A,C,D)->(B,C,A,1,D),(B,1,A,C,D)->(B,C,'A',C,'D')->(B,C,C) + tmp1 = self.einsum1((BMatrix, tmp0)) + # (B,C,C) + tmp2 = self.inv(tmp1) + + # (B,1,A,D) <- (B,A,D) + pos_old = self.broadcast(F.expand_dims(coordinate_old, -3)) + # (B,C,D) <- (B,C,A,D) = (B,C,A,1) * (B,1,A,D) + pos_old_0 = F.reduce_sum(self.mask0 * pos_old, -2) + pos_old_1 = F.reduce_sum(self.mask1 * pos_old, -2) + # (B,C) + di = self.get_distance(pos_old_0, pos_old_1, pbc_box) + + # ijkl,ikl->ij + # (B,C,A,D),(B,A,D)->(B,C) + # (B,C,A,D),(B,1,A,D)->(B,C,A,D)->(B,C) + tmp3 = self.einsum2((BMatrix, coordinate_new)) - di + + # ijk,ik->ij + # (B,C,C),(B,C)->(B,C) + # (B,C,C),(B,1,C)->(B,C,'C')->(B,C) + tmp4 = self.einsum3((tmp2, tmp3)) + + # ijkl,ij->ikl + # (B,C,A,D),(B,C)->(B,A,D) + # (B,A,C,D),(B,1,C,1)->(B,A,C,D)->(B,A,D) + tmp5 = self.einsum4((BMatrix, tmp4)) + + # ijk,ikl->ijl + # (B,A,A),(B,A,D)->(B,A,D) + # (B,A,A,1),(B,1,A,D)->(B,A,'A',D)->(B,A,D) + dr = -self.einsum5((self.Mii, tmp5)) + coordinate = coordinate_new + dr + + # (B,A,D) + velocity += dr / self.time_step + # Constraint force = m * dR / dt^2 + # (B,A,1) * (B,A,D) + constraint_force = self._atom_mass * dr / (self.time_step**2) + force += constraint_force + if self._pbc_box is not None: + # 
(B,D) <- (B,A,D)
+            virial += F.reduce_sum(-0.5 * coordinate * constraint_force, -2)
+
+        return coordinate, velocity, force, energy, kinetics, virial, pbc_box
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/controller.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..eea85fe96850ca1bf4b7721104533e8fa0c6a1ed
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/controller.py
@@ -0,0 +1,292 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+#                       Peking University &
+#                       Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Controller
+"""
+
+import mindspore as ms
+from mindspore import Tensor
+from mindspore.nn import Cell
+from mindspore import ops
+from mindspore.ops import functional as F
+
+from ..system import Molecule
+from ..function import functions as func
+from ..function.functions import get_integer
+
+
+class Controller(Cell):
+    r"""
+    Base class of controllers that adjust the parameters of the simulation process,
+    such as the integrator, thermostat, barostat and constraint.
+
+    Args:
+        system (Molecule): Simulation system.
+        control_step (int): Step interval for controller execution. Default: 1
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def __init__(self,
+                 system: Molecule,
+                 control_step: int = 1,
+                 ):
+
+        super().__init__(auto_prefix=False)
+
+        self.system = system
+        self.num_walker = self.system.num_walker
+        self.num_atoms = system.num_atoms
+        self.dimension = system.dimension
+
+        self.sys_dofs = system.degrees_of_freedom
+        self.degrees_of_freedom = system.degrees_of_freedom
+
+        self.time_step = Tensor(1e-3, ms.float32)
+
+        self._coordinate = self.system.coordinate
+        self._pbc_box = self.system.pbc_box
+
+        self.units = self.system.units
+        self.boltzmann = self.units.boltzmann
+        self.kinetic_unit_scale = self.units.kinetic_ref
+        self.press_unit_scale = self.units.pressure_ref
+
+        # (B,A)
+        self.atom_mass = self.system.atom_mass
+        self.inv_mass = self.system.inv_mass
+        # (B,A,1)
+        self._atom_mass = F.expand_dims(self.atom_mass, -1)
+        self._inv_mass = F.expand_dims(self.inv_mass, -1)
+
+        # (B,1)
+        self.system_mass = self.system.system_mass
+        self.system_natom = self.system.system_natom
+
+        self.control_step = get_integer(control_step)
+        if self.control_step <= 0:
+            raise ValueError('The "control_step" must be larger than 0!')
+
+        self.num_constraints = 0
+
+        self.identity = ops.Identity()
+        self.keepdim_sum = ops.ReduceSum(keep_dims=True)
+
+    def set_time_step(self, dt: float):
+        """
+        set simulation time step.
+
+        Args:
+            dt (float): Time of a time step.
+        """
+        self.time_step = Tensor(dt, ms.float32)
+        return self
+
+    def set_degrees_of_freedom(self, dofs: int):
+        """
+        set degrees of freedom (DOFs).
+
+        Args:
+            dofs (int): degrees of freedom.
+        """
+        self.degrees_of_freedom = get_integer(dofs)
+        return self
+
+    def update_coordinate(self, coordinate: Tensor, success: bool = True) -> bool:
+        """
+        update the parameter of coordinate.
+
+        Args:
+            coordinate (Tensor): A tensor of the coordinate parameter.
+            success (bool): Whether the parameter is updated successfully.
+
+        Returns:
+            bool.
+        """
+        success = F.depend(success, F.assign(self._coordinate, coordinate))
+        return success
+
+    def update_pbc_box(self, pbc_box: Tensor, success: bool = True) -> bool:
+        """
+        update the parameter of the PBC box.
+
+        Args:
+            pbc_box (Tensor): A tensor of the PBC box parameter.
+            success (bool): Whether the parameter is updated successfully.
+
+        Returns:
+            bool.
+        """
+        if self._pbc_box is None:
+            return success
+        return F.depend(success, F.assign(self._pbc_box, pbc_box))
+
+    def get_kinetics(self, velocity: Tensor) -> Tensor:
+        """
+        calculate kinetics according to velocity.
+
+        Args:
+            velocity (Tensor): A tensor of velocity.
+
+        Returns:
+            Tensor, kinetics according to velocity.
+        """
+        if velocity is None:
+            return None
+        # (B,A,D) * (B,A,1)
+        k = 0.5 * self._atom_mass * velocity**2
+        # (B,D) <- (B,A,D)
+        kinetics = F.reduce_sum(k, -2)
+        return kinetics * self.kinetic_unit_scale
+
+    def get_temperature(self, kinetics: Tensor = None) -> Tensor:
+        """
+        calculate temperature according to kinetics.
+
+        Args:
+            kinetics (Tensor): A tensor of kinetics.
+
+        Returns:
+            Tensor, temperature according to kinetics.
+        """
+        if kinetics is None:
+            return None
+        # (B) <- (B,D)
+        kinetics = F.reduce_sum(kinetics, -1)
+        return 2 * kinetics / self.degrees_of_freedom / self.boltzmann
+
+    def get_volume(self, pbc_box: Tensor) -> Tensor:
+        """
+        calculate volume according to the PBC box.
+
+        Args:
+            pbc_box (Tensor): A PBC box tensor used to calculate the volume.
+
+        Returns:
+            Tensor, volume according to the PBC box.
+        """
+        if self._pbc_box is None:
+            return None
+        # (B,1) <- (B,D)
+        return func.keepdim_prod(pbc_box, -1)
+
+    def get_virial(self, pbc_grad, pbc_box):
+        """
+        calculate virial according to the PBC box and its gradients.
+
+        Args:
+            pbc_grad (Tensor): Tensor of the PBC box's gradients.
+            pbc_box (Tensor): Tensor of the PBC box.
+
+        Returns:
+            Tensor, virial.
+        """
+        # (B,D)
+        return 0.5 * pbc_grad * pbc_box
+
+    def get_pressure(self, kinetics: Tensor, virial: Tensor, pbc_box: Tensor) -> Tensor:
+        """
+        calculate pressure according to kinetics, virial and the PBC box.
+
+        Args:
+            kinetics (Tensor): Tensor of kinetics.
+            virial (Tensor): Tensor of virial.
+            pbc_box (Tensor): Tensor of the PBC box.
+
+        Returns:
+            Tensor, pressure according to kinetics, virial and the PBC box.
+        """
+        if self._pbc_box is None:
+            return None
+        volume = func.keepdim_prod(pbc_box, -1)
+        # (B,D) = ((B,D) - (B,D)) / (B,1)
+        pressure = 2 * (kinetics - virial) / volume
+        return pressure * self.press_unit_scale
+
+    def get_com(self, coordinate: Tensor) -> Tensor:
+        """
+        get coordinate of center of mass.
+
+        Args:
+            coordinate (Tensor): Tensor of coordinate.
+
+        Returns:
+            Tensor, coordinate of center of mass.
+        """
+        return self.keepdim_sum(coordinate * self._atom_mass, -2) / F.expand_dims(self.system_mass, -1)
+
+    def get_com_velocity(self, velocity: Tensor) -> Tensor:
+        """
+        calculate velocity of center of mass.
+
+        Args:
+            velocity (Tensor): Tensor of velocity.
+
+        Returns:
+            Tensor, velocity of center of mass.
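+
+        The result is the mass-weighted average
+        :math:`v_{com} = \sum_i m_i v_i / \sum_i m_i`.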
+ """ + # (B,A,D) * (B,A,1) -> (B,1,D) + # (B,1,D) / (B,1,1) + return self.keepdim_sum(velocity * self._atom_mass, -2) / F.expand_dims(self.system_mass, -1) + + def construct(self, + coordinate: Tensor, + velocity: Tensor, + force: Tensor, + energy: Tensor, + kinetics: Tensor, + virial: Tensor = None, + pbc_box: Tensor = None, + step: int = 0, + ): + + r""" + Control the parameters during the simulation. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + velocity (Tensor): Tensor of shape (B, A, D). Data type is float. + force (Tensor): Tensor of shape (B, A, D). Data type is float. + energy (Tensor): Tensor of shape (B, 1). Data type is float. + kinetics (Tensor): Tensor of shape (B, D). Data type is float. + virial (Tensor): Tensor of shape (B, D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + step (int): Simulation step. Default: 0 + + Returns: + - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float. + - velocity (Tensor), Tensor of shape (B, A, D). Data type is float. + - force (Tensor), Tensor of shape (B, A, D). Data type is float. + - energy (Tensor), Tensor of shape (B, 1). Data type is float. + - kinetics (Tensor), Tensor of shape (B, D). Data type is float. + - virial (Tensor), Tensor of shape (B, D). Data type is float. + - pbc_box (Tensor), Tensor of shape (B, D). Data type is float. + + Symbols: + B: Number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + """ + #pylint: disable=unused-argument + + return coordinate, velocity, force, energy, kinetics, virial, pbc_box diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2d6ef7797422fb1db11fee28e6e9c5b27368c3 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Simulation integrator""" + +from .integrator import Integrator +from .leapfrog import LeapFrog +from .velocityverlet import VelocityVerlet +from .brownian import Brownian + +__all__ = ['Integrator', 'LeapFrog', 'VelocityVerlet', 'Brownian'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/brownian.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/brownian.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbca23836b3d8f1d578c5e9e7747c756ad0de50 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/brownian.py @@ -0,0 +1,151 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Brownian integrator +""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor +from mindspore import ops +from mindspore.ops import functional as F + +from .integrator import Integrator +from ...system import Molecule + + +class Brownian(Integrator): + r""" + Brownian integrator. + + Args: + system (Molecule): Simulation system. + temperature (float): Simulation temperature T (K). Default: 300 + friction_coefficient (float): Friction coefficient g (amu/ps). Default: 1e3 + + Returns: + - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float. + - velocity (Tensor), Tensor of shape (B, A, D). Data type is float. + - force (Tensor), Tensor of shape (B, A, D). Data type is float. + - energy (Tensor), Tensor of shape (B, 1). Data type is float. + - kinetics (Tensor), Tensor of shape (B, D). Data type is float. + - virial (Tensor), Tensor of shape (B, D). Data type is float. + - pbc_box (Tensor), Tensor of shape (B, D). Data type is float. 
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def __init__(self,
+                 system: Molecule,
+                 temperature: float = 300,
+                 friction_coefficient: float = 1e3,
+                 ):
+
+        super().__init__(
+            system=system,
+            thermostat=None,
+            barostat=None,
+            constraint=None,
+        )
+
+        self.ref_temp = Tensor(temperature, ms.float32)
+
+        self.inv_sqrt_mass = F.sqrt(self._inv_mass)
+
+        self.friction_coefficient = Tensor(friction_coefficient, ms.float32)
+        # inverse friction per atom: 1 / (g * m)
+        self.inv_gamma = msnp.reciprocal(self.friction_coefficient) * self._inv_mass
+
+        # k = \sqrt(2 * k_B * T * dt / \gamma)
+        self.random_scale = F.sqrt(2 * self.boltzmann * self.ref_temp * self.time_step
+                                   * self.inv_gamma / self.kinetic_unit_scale)
+
+        self.normal = ops.StandardNormal()
+
+        self.concat_last_dim = ops.Concat(axis=-1)
+        self.concat_penulti = ops.Concat(axis=-2)
+        self.keep_mean = ops.ReduceMean(keep_dims=True)
+
+    @property
+    def temperature(self) -> Tensor:
+        return self.ref_temp
+
+    def set_thermostat(self, thermostat: None = None):
+        """
+        set thermostat algorithm for integrator.
+
+        Args:
+            thermostat (None): The Brownian integrator does not accept a thermostat,
+                so only None is valid. Default: None
+        """
+        if thermostat is not None:
+            raise ValueError('The Brownian integrator cannot accept a thermostat')
+        return self
+
+    def set_barostat(self, barostat: None = None):
+        """
+        set barostat algorithm for integrator.
+
+        Args:
+            barostat (None): The Brownian integrator does not accept a barostat,
+                so only None is valid. Default: None
+        """
+        if barostat is not None:
+            raise ValueError('The Brownian integrator cannot accept a barostat')
+        return self
+
+    def set_constraint(self, constraint: None = None):
+        """
+        set constraint algorithm for integrator.
+
+        Args:
+            constraint (None): The Brownian integrator does not accept a constraint,
+                so only None is valid. Default: None
+        """
+        if constraint is not None:
+            raise ValueError('The Brownian integrator cannot accept a constraint')
+        return self
+
+    def set_time_step(self, dt: float):
+        """
+        set simulation time step.
+
+        Args:
+            dt (float): Time of a time step.
+        """
+        self.time_step = Tensor(dt, ms.float32)
+        # keep consistent with __init__: include the kinetic unit scale
+        self.random_scale = F.sqrt(2 * self.boltzmann * self.ref_temp * self.time_step
+                                   * self.inv_gamma / self.kinetic_unit_scale)
+        return self
+
+    def construct(self,
+                  coordinate: Tensor,
+                  velocity: Tensor,
+                  force: Tensor,
+                  energy: Tensor,
+                  kinetics: Tensor,
+                  virial: Tensor = None,
+                  pbc_box: Tensor = None,
+                  step: int = 0,
+                  ):
+
+        coordinate += self.acc_unit_scale * force * self.inv_gamma * self.time_step
+        coordinate += self.normal(coordinate.shape) * self.random_scale
+
+        return coordinate, velocity, force, energy, kinetics, virial, pbc_box
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/integrator.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/integrator.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d50dda73a0daffbe74b5813708ad96a9a96df39
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/integrator.py
@@ -0,0 +1,251 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Integrator
+"""
+
+import mindspore as ms
+from mindspore import Tensor
+from mindspore.nn import CellList
+
+from .. import Controller
+from ..thermostat import Thermostat
+from ..barostat import Barostat
+from ..constraint import Constraint
+from ...system import Molecule
+from ...function.functions import get_integer
+
+
+class Integrator(Controller):
+    r"""
+    Integrator for simulation.
+
+    Args:
+        system (Molecule): Simulation system.
+        thermostat (Thermostat): Thermostat for temperature coupling. Default: None
+        barostat (Barostat): Barostat for pressure coupling. Default: None
+        constraint (Constraint): Constraint algorithm. Default: None
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self,
+                 system: Molecule,
+                 thermostat: Thermostat = None,
+                 barostat: Barostat = None,
+                 constraint: Constraint = None,
+                 ):
+
+        super().__init__(
+            system=system,
+            control_step=1,
+        )
+
+        self.kinetic_unit_scale = Tensor(self.units.kinetic_ref, ms.float32)
+        self.acc_unit_scale = Tensor(self.units.acceleration_ref, ms.float32)
+
+        self.boltzmann = self.units.boltzmann
+
+        self.thermostat = None
+        self.set_thermostat(thermostat)
+
+        self.barostat = None
+        self.set_barostat(barostat)
+
+        self.constraint = None
+        self.num_constraint_controller = 0
+        self.set_constraint(constraint)
+
+    def set_time_step(self, dt: float):
+        """
+        set simulation time step.
+
+        Args:
+            dt (float): Time of a time step.
+        """
+        self.time_step = Tensor(dt, ms.float32)
+        if self.thermostat is not None:
+            self.thermostat.set_time_step(dt)
+        if self.barostat is not None:
+            self.barostat.set_time_step(dt)
+        if self.constraint is not None:
+            for i in range(self.num_constraint_controller):
+                self.constraint[i].set_time_step(dt)
+        return self
+
+    def set_degrees_of_freedom(self, dofs: int):
+        """
+        set degrees of freedom (DOFs).
+
+        Args:
+            dofs (int): Degrees of freedom.
+        """
+        self.degrees_of_freedom = get_integer(dofs)
+        if self.thermostat is not None:
+            self.thermostat.set_degrees_of_freedom(dofs)
+        if self.barostat is not None:
+            self.barostat.set_degrees_of_freedom(dofs)
+        if self.constraint is not None:
+            for i in range(self.num_constraint_controller):
+                self.constraint[i].set_degrees_of_freedom(dofs)
+        return self
+
+    def set_thermostat(self, thermostat: Thermostat):
+        """
+        set thermostat algorithm for integrator.
+
+        Args:
+            thermostat (Thermostat): The thermostat.
+        """
+        if self.thermostat is not None:
+            # guard against an AttributeError when the new thermostat is None
+            new_name = 'None' if thermostat is None else str(thermostat.cls_name)
+            print('Warning! The thermostat for this integrator has already been set to "'
+                  + str(self.thermostat.cls_name) + '" but will now be changed to "' + new_name + '".')
+        if thermostat is None:
+            self.thermostat = None
+        else:
+            self.thermostat = thermostat
+            self.thermostat.set_degrees_of_freedom(self.degrees_of_freedom)
+            self.thermostat.set_time_step(self.time_step)
+        return self
+
+    def set_barostat(self, barostat: Barostat):
+        """
+        set barostat algorithm for integrator.
+
+        Args:
+            barostat (Barostat): The barostat.
+ """ + if self.barostat is not None: + print('Warning! The barostat for this integrator has already been set to "' + + str(self.barostat.cls_name)+'" but will now be changed to "'+str(barostat.cls_name)+'".') + if barostat is None: + self.barostat = None + else: + self.barostat = barostat + self.barostat.set_degrees_of_freedom(self.degrees_of_freedom) + self.barostat.set_time_step(self.time_step) + return self + + def set_constraint(self, constraint: Constraint): + """ + set constraint algorithm for integrator. + + Args: + constraint (Constraint): The constraints. + """ + if self.constraint is not None: + print('Warning! The constraint for this integrator has already been set to "' + + str(self.constraint.cls_name)+'" but will now be changed to "'+str(constraint.cls_name)+'".') + self.num_constraints = 0 + if constraint is None: + self.constraint = None + self.num_constraint_controller = 0 + else: + if isinstance(constraint, Controller): + self.num_constraint_controller = 1 + constraint = [constraint] + elif isinstance(constraint, list): + self.num_constraint_controller = len(constraint) + else: + raise ValueError('The type of "constraint" must be Controller or list but got: ' + + str(type(constraint))) + + self.constraint = CellList(constraint) + for i in range(self.num_constraint_controller): + self.num_constraints += self.constraint[i].num_constraints + self.constraint[i].set_time_step(self.time_step) + degrees_of_freedom = self.sys_dofs - self.num_constraints + self.set_degrees_of_freedom(degrees_of_freedom) + + return self + + def add_constraint(self, constraint: Constraint): + """ + add constraint algorithm for integrator. + + Args: + constraint (Constraint): The constraints. + """ + if isinstance(constraint, Controller): + constraint = [constraint] + num_constraint_controller = 1 + elif isinstance(constraint, list): + num_constraint_controller = len(constraint) + else: + raise ValueError('The type of "constraint" must be Controller or list but got: ' + + str(type(constraint))) + + if self.constraint is None: + return self.set_constraint(constraint) + + self.num_constraint_controller += num_constraint_controller + self.constraint.extend(constraint) + for i in range(self.num_constraint_controller): + self.num_constraints += self.constraint[i].num_constraints + self.constraint[i].set_time_step(self.time_step) + degrees_of_freedom = self.sys_dofs - self.num_constraints + self.set_degrees_of_freedom(degrees_of_freedom) + + return self + + def construct(self, + coordinate: Tensor, + velocity: Tensor, + force: Tensor, + energy: Tensor, + kinetics: Tensor, + virial: Tensor = None, + pbc_box: Tensor = None, + step: int = 0, + ): + r""" + update simulation step + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + velocity (Tensor): Tensor of shape (B, A, D). Data type is float. + force (Tensor): Tensor of shape (B, A, D). Data type is float. + energy (Tensor): Tensor of shape (B, 1). Data type is float. + kinetics (Tensor): Tensor of shape (B, D). Data type is float. + virial (Tensor): Tensor of shape (B, D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + step (int): Simulation step. Default: 0 + + Returns: + - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float. + - velocity (Tensor), Tensor of shape (B, A, D). Data type is float. + - force (Tensor), Tensor of shape (B, A, D). Data type is float. + - energy (Tensor), Tensor of shape (B, 1). Data type is float. + - kinetics (Tensor), Tensor of shape (B, D). 
Data type is float. + - virial (Tensor), Tensor of shape (B, D). Data type is float. + - pbc_box (Tensor), Tensor of shape (B, D). Data type is float. + + Symbols: + B: Number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + """ + + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/leapfrog.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/leapfrog.py new file mode 100644 index 0000000000000000000000000000000000000000..0bccda857fb36b07bf4866c08aa6babe9af1751e --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/leapfrog.py @@ -0,0 +1,118 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Leap-frog integrator +""" + +from mindspore import Tensor + +from .integrator import Integrator +from ..thermostat import Thermostat +from ..barostat import Barostat +from ..constraint import Constraint +from ...system import Molecule + + +class LeapFrog(Integrator): + r""" + A leap-frog integrator based on "middle scheme" developed by Jian Liu, et al. + + Reference: + `Zhang, Z.; Yan, K; Liu, X.; Liu, J.. + A Leap-Frog Algorithm-based Efficient Unified Thermostat Scheme for Molecular Dynamics [J]. + Chinese Science Bulletin, 2018, 63(33): 3467-3483. + `_. + + Args: + system (Molecule): Simulation system. + thermostat (Thermostat): Thermostat for temperature coupling. Default: None + barostat (Barostat): Barostat for pressure coupling. Default: None + constraint (Constraint): Constraint algorithm. Default: None + + Returns: + - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float. + - velocity_half (Tensor), Tensor of shape (B, A, D). Data type is float. + - force (Tensor), Tensor of shape (B, A, D). Data type is float. + - energy (Tensor), Tensor of shape (B, 1). Data type is float. + - kinetics (Tensor), Tensor of shape (B, D). Data type is float. + - virial (Tensor), Tensor of shape (B, D). Data type is float. + - pbc_box (Tensor), Tensor of shape (B, D). Data type is float. 
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def __init__(self,
+                 system: Molecule,
+                 thermostat: Thermostat = None,
+                 barostat: Barostat = None,
+                 constraint: Constraint = None,
+                 ):
+
+        super().__init__(
+            system=system,
+            thermostat=thermostat,
+            barostat=barostat,
+            constraint=constraint,
+        )
+
+    def construct(self,
+                  coordinate: Tensor,
+                  velocity: Tensor,
+                  force: Tensor,
+                  energy: Tensor,
+                  kinetics: Tensor,
+                  virial: Tensor = None,
+                  pbc_box: Tensor = None,
+                  step: int = 0,
+                  ):
+
+        # (B,A,D) = (B,A,D) * (B,A,1)
+        acceleration = self.acc_unit_scale * force * self._inv_mass
+
+        # v(t+0.5) = v(t-0.5) + a(t) * dt
+        velocity_half = velocity + acceleration * self.time_step
+        # (B,A,D) = (B,A,D) - (B,1,D)
+        velocity_half -= self.get_com_velocity(velocity_half)
+        kinetics = self.get_kinetics(velocity_half)
+
+        # R(t+0.5) = R(t) + 0.5 * v(t+0.5) * dt
+        coordinate_half = coordinate + velocity_half * self.time_step * 0.5
+
+        if self.thermostat is not None:
+            # v'(t+0.5) = f_T[v(t+0.5)]
+            coordinate_half, velocity_half, force, energy, kinetics, virial, pbc_box = \
+                self.thermostat(coordinate_half, velocity_half, force, energy, kinetics, virial, pbc_box, step)
+
+        # R(t+1) = R(t+0.5) + 0.5 * v'(t+0.5) * dt
+        coordinate_new = coordinate_half + velocity_half * self.time_step * 0.5
+
+        if self.constraint is not None:
+            for i in range(self.num_constraint_controller):
+                coordinate_new, velocity_half, force, energy, kinetics, virial, pbc_box = \
+                    self.constraint[i](coordinate_new, velocity_half, force, energy, kinetics, virial, pbc_box, step)
+
+        if self.barostat is not None:
+            coordinate_new, velocity_half, force, energy, kinetics, virial, pbc_box = \
+                self.barostat(coordinate_new, velocity_half, force, energy, kinetics, virial, pbc_box, step)
+
+        return coordinate_new, velocity_half, force, energy, kinetics, virial, pbc_box
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/velocityverlet.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/velocityverlet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a152b94c6870bcc81fc184339dcab4e58848f387
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/integrator/velocityverlet.py
@@ -0,0 +1,146 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
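+#
+# Sketch of the "middle scheme" update order implemented below (one step of
+# length dt; v' denotes the thermostat-coupled velocity):
+#     a(t)      = F(t) / m
+#     v(t+1/2)  = v(t) + 0.5 * a(t) * dt       # half kick
+#     R(t+1/2)  = R(t) + 0.5 * v(t+1/2) * dt   # half drift
+#     v'(t+1/2) = f_T[v(t+1/2)]                # temperature coupling
+#     R(t+1)    = R(t+1/2) + 0.5 * v'(t+1/2) * dt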
+# ============================================================================
+"""
+Velocity Verlet integrator
+"""
+
+import mindspore.numpy as msnp
+from mindspore.ops import functional as F
+from mindspore import Tensor, Parameter
+
+from .integrator import Integrator
+from ..thermostat import Thermostat
+from ..barostat import Barostat
+from ..constraint import Constraint
+from ...system import Molecule
+
+
+class VelocityVerlet(Integrator):
+    r"""
+    A velocity Verlet integrator based on the "middle scheme" developed by Jian Liu, et al.
+
+    Reference:
+        `Zhang, Z.; Liu, X.; Chen, Z.; Zheng, H.; Yan, K.; Liu, J.
+        A Unified Thermostat Scheme for Efficient Configurational Sampling for
+        Classical/Quantum Canonical Ensembles via Molecular Dynamics [J].
+        The Journal of Chemical Physics, 2017, 147(3): 034109.
+        `_.
+
+    Args:
+        system (Molecule): Simulation system.
+        thermostat (Thermostat): Thermostat for temperature coupling. Default: None
+        barostat (Barostat): Barostat for pressure coupling. Default: None
+        constraint (Constraint): Constraint algorithm. Default: None
+
+    Returns:
+        - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - velocity (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - force (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - energy (Tensor), Tensor of shape (B, 1). Data type is float.
+        - kinetics (Tensor), Tensor of shape (B, D). Data type is float.
+        - virial (Tensor), Tensor of shape (B, D). Data type is float.
+        - pbc_box (Tensor), Tensor of shape (B, D). Data type is float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self,
+                 system: Molecule,
+                 thermostat: Thermostat = None,
+                 barostat: Barostat = None,
+                 constraint: Constraint = None,
+                 ):
+
+        super().__init__(
+            system=system,
+            thermostat=thermostat,
+            barostat=barostat,
+            constraint=constraint,
+        )
+
+        # v(t+0.5) = v(t) + 0.5 * a(t) * dt
+        velocity_half = msnp.zeros_like(self.system.coordinate)
+        self.velocity_half = Parameter(velocity_half, name='velocity_half')
+
+    def set_velocity_half(self, velocity_half: Tensor, success: bool = True) -> bool:
+        """
+        set the velocity of the previous half step.
+
+        Args:
+            velocity_half (Tensor): Tensor of velocity before half step.
+            success (bool): Whether the velocity has been set successfully.
+ """ + return F.depend(success, F.assign(self.velocity_half, velocity_half)) + + def construct(self, + coordinate: Tensor, + velocity: Tensor, + force: Tensor, + energy: Tensor, + kinetics: Tensor, + virial: Tensor = None, + pbc_box: Tensor = None, + step: int = 0, + ): + + acceleration = self.acc_unit_scale * force * self._inv_mass + + # if t > 0: v(t) = v(t-0.5) + 0.5 * a(t) * dt + velocity = msnp.where(step > 0, self.velocity_half + + 0.5 * acceleration * self.time_step, velocity) + # (B,A,D) = (B,A,D) - (B,1,D) + velocity -= self.get_com_velocity(velocity) + + # v(t+0.5) = v(t) + 0.5 * a(t) * dt + velocity_half = velocity + 0.5 * acceleration * self.time_step + + # R(t+0.5) = R(t) + 0.5 * v(t+0.5) * dt + coordinate_half = coordinate + velocity_half * self.time_step * 0.5 + + if self.thermostat is not None: + # v'(t) = f_T[v(t)] + kinetics = self.get_kinetics(velocity_half) + coordinate_half, velocity_half, force, energy, kinetics, virial, pbc_box = \ + self.thermostat(coordinate_half, velocity_half, + force, energy, kinetics, virial, pbc_box, step) + + # R(t+1) = R(t+0.5) + 0.5 * v'(t) * dt + coordinate_new = coordinate_half + velocity_half * self.time_step * 0.5 + + if self.constraint is not None: + for i in range(self.num_constraint_controller): + coordinate_new, velocity_half, force, energy, kinetics, virial, pbc_box = \ + self.constraint[i]( + coordinate_new, velocity_half, force, energy, kinetics, virial, pbc_box, step) + + if self.barostat is not None: + coordinate_new, velocity_half, force, energy, kinetics, virial, pbc_box = \ + self.barostat(coordinate_new, velocity_half, force, + energy, kinetics, virial, pbc_box, step) + + F.depend(True, F.assign(self.velocity_half, velocity_half)) + + kinetics = self.get_kinetics(velocity) + + return coordinate_new, velocity, force, energy, kinetics, virial, pbc_box diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d06310a7abb62373ee5f003a3bffb3b6ebbb3f76 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Thermostat"""
+
+from .thermostat import Thermostat
+from .berendsen import BerendsenThermostat
+from .langevin import Langevin
+
+__all__ = ['Thermostat', 'BerendsenThermostat', 'Langevin']
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/berendsen.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/berendsen.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e0e7c0b8b9793f0a5af80b0027dca21ca2a3cc6
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/berendsen.py
@@ -0,0 +1,112 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Berendsen thermostat"""
+
+from mindspore import Tensor
+from mindspore import ops
+
+from . import Thermostat
+from ...system import Molecule
+
+
+class BerendsenThermostat(Thermostat):
+    r"""
+    A Berendsen (weak coupling) thermostat controller.
+
+    Reference:
+        `Berendsen, H. J. C.; Postma, J. P. M.; van Gunsteren, W. F.; DiNola, A.; Haak, J. R.
+        Molecular Dynamics with Coupling to an External Bath [J].
+        The Journal of Chemical Physics, 1984, 81(8): 3684.
+        `_.
+
+    Args:
+        system (Molecule): Simulation system.
+        temperature (float): Reference temperature T_ref (K) for temperature coupling.
+            Default: 300
+        control_step (int): Step interval for controller execution. Default: 1
+        time_constant (float): Time constant \tau_T (ps) for temperature coupling.
+            Default: 4
+        scale_min (float): The minimum value to clip the velocity scale factor. Default: 0.8
+        scale_max (float): The maximum value to clip the velocity scale factor. Default: 1.25
+
+    Returns:
+        - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - velocity (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - force (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - energy (Tensor), Tensor of shape (B, 1). Data type is float.
+        - kinetics (Tensor), Tensor of shape (B, D). Data type is float.
+        - virial (Tensor), Tensor of shape (B, D). Data type is float.
+        - pbc_box (Tensor), Tensor of shape (B, D). Data type is float.
+ + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, + system: Molecule, + temperature: float = 300, + control_step: int = 1, + time_constant: float = 4, + scale_min: float = 0.8, + scale_max: float = 1.25, + ): + + super().__init__( + system=system, + temperature=temperature, + control_step=control_step, + time_constant=time_constant, + ) + + self.scale_min = scale_min + self.scale_max = scale_max + + self.ratio = self.control_step * self.time_step / self.time_constant + + def set_time_step(self, dt): + """ + set simulation time step. + + Args: + dt (float): Time of a time step. + """ + self.time_step = dt + self.ratio = self.control_step * self.time_step / self.time_constant + return self + + def construct(self, + coordinate: Tensor, + velocity: Tensor, + force: Tensor, + energy: Tensor, + kinetics: Tensor, + virial: Tensor = None, + pbc_box: Tensor = None, + step: int = 0, + ): + + if self.control_step == 1 or step % self.control_step == 0: + scale = self.velocity_scale(kinetics, self.ref_kinetics, self.ratio) + scale = ops.clip_by_value(scale, self.scale_min, self.scale_max) + velocity *= scale + + return coordinate, velocity, force, energy, kinetics, virial, pbc_box diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/langevin.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/langevin.py new file mode 100644 index 0000000000000000000000000000000000000000..39aa6cef8dc4996f453a71b56f007e862f1737e0 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/langevin.py @@ -0,0 +1,129 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Langevin thermostat""" + +import mindspore.numpy as msnp +from mindspore import Tensor +from mindspore import ops +from mindspore.ops import functional as F + +from .thermostat import Thermostat +from ...system import Molecule + + +class Langevin(Thermostat): + r""" + A Langevin thermostat controller. + + Reference: + `Goga, N.; Rzepiela, A. J.; de Vries, A. H.; Marrink, S. J.; Berendsen, H. J. C.. + Efficient Algorithms for Langevin and DPD Dynamics [J]. + Journal of Chemical Theory and Computation, 2012, 8(10): 3637-3649. + `_. + + Args: + system (Molecule): Simulation system. + temperature (float): Reference temperature T_ref (K) for temperature coupling. + Default: 300 + control_step (int): Step interval for controller execution. Default: 1 + time_constant (float): Time constant \tau_T (ps) for temperature coupling. + Default: 2 + seed (int): Random seed for standard normal. Default: 0 + seed2 (int): Random seed2 for standard normal. 
Default: 0 + + Returns: + - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float. + - velocity (Tensor), Tensor of shape (B, A, D). Data type is float. + - force (Tensor), Tensor of shape (B, A, D). Data type is float. + - energy (Tensor), Tensor of shape (B, 1). Data type is float. + - kinetics (Tensor), Tensor of shape (B, D). Data type is float. + - virial (Tensor), Tensor of shape (B, D). Data type is float. + - pbc_box (Tensor), Tensor of shape (B, D). Data type is float. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + system: Molecule, + temperature: float = 300, + control_step: int = 1, + time_constant: float = 2, + seed: int = 0, + seed2: int = 0, + ): + + super().__init__( + system=system, + temperature=temperature, + control_step=control_step, + time_constant=time_constant, + ) + + # (B,A,1) + self._inv_sqrt_mass = F.sqrt(self._inv_mass) + + # (B,1,1) + # \gamma = 1.0 / \tau_t + self.effective_friction_rate = msnp.reciprocal(self.time_constant) + # \f = 1 - exp(-\gamma * dt) + self.friction = 1.0 - \ + msnp.exp(-self.effective_friction_rate*self.time_step) + # k = \sqrt(f * (2 - f) * k_B * T) + self.random_scale = F.sqrt(self.friction * (2 - self.friction) * self.boltzmann * + self.ref_temp / self.kinetic_unit_scale) + + self.standard_normal = ops.StandardNormal(seed, seed2) + + def set_time_step(self, dt): + """ + set simulation time step. + + Args: + dt (float): Time of a time step. + """ + self.time_step = dt + # \f = 1 - exp(-\gamma * dt) + self.friction = 1.0 - \ + msnp.exp(-self.effective_friction_rate*self.time_step) + # k = \sqrt(f * (2 - f) * k_B * T) + self.random_scale = F.sqrt(self.friction * (2 - self.friction) * self.boltzmann * + self.ref_temp / self.kinetic_unit_scale) + return self + + def construct(self, + coordinate: Tensor, + velocity: Tensor, + force: Tensor, + energy: Tensor, + kinetics: Tensor, + virial: Tensor = None, + pbc_box: Tensor = None, + step: int = 0, + ): + + if self.control_step == 1 or step % self.control_step == 0: + velocity += -self.friction * velocity + self.random_scale * \ + self._inv_sqrt_mass * self.standard_normal(velocity.shape) + + return coordinate, velocity, force, energy, kinetics, virial, pbc_box diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/thermostat.py b/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/thermostat.py new file mode 100644 index 0000000000000000000000000000000000000000..fac26e33eb30c598e28ae312f4e8c81c2576aa4a --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/control/thermostat/thermostat.py @@ -0,0 +1,160 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
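+#
+# The weak-coupling scale factor computed by `velocity_scale` below is
+#     lambda = sqrt(1 + ratio * (E_kin_ref / E_kin - 1))
+# A worked example with illustrative numbers, using T_ref / T as a proxy for
+# the kinetic-energy ratio (ratio = dt / \tau_T):
+#     ratio = 0.001 / 4.0                           # dt = 1 fs, tau_T = 4 ps
+#     lam = (1 + ratio * (300.0 / 290.0 - 1))**0.5  # ~1.0000043, a gentle rescale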
+# ============================================================================
+"""
+Thermostat
+"""
+
+import mindspore as ms
+from mindspore import Tensor
+from mindspore.ops import functional as F
+
+from .. import Controller
+from ...system import Molecule
+from ...function import functions as func
+
+
+class Thermostat(Controller):
+    r"""
+    Thermostat controller for temperature coupling.
+
+    Args:
+        system (Molecule): Simulation system.
+        temperature (float): Reference temperature T_ref (K) for temperature coupling.
+            Default: 300
+        control_step (int): Step interval for controller execution. Default: 1
+        time_constant (float): Time constant \tau_T (ps) for temperature coupling.
+            Default: 4
+
+    Returns:
+        - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - velocity (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - force (Tensor), Tensor of shape (B, A, D). Data type is float.
+        - energy (Tensor), Tensor of shape (B, 1). Data type is float.
+        - kinetics (Tensor), Tensor of shape (B, D). Data type is float.
+        - virial (Tensor), Tensor of shape (B, D). Data type is float.
+        - pbc_box (Tensor), Tensor of shape (B, D). Data type is float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def __init__(self,
+                 system: Molecule,
+                 temperature: float = 300,
+                 control_step: int = 1,
+                 time_constant: float = 4.,
+                 ):
+
+        super().__init__(
+            system=system,
+            control_step=control_step,
+        )
+
+        self.boltzmann = self.units.boltzmann
+        self.kinetic_unit_scale = self.units.kinetic_ref
+
+        self.ref_temp = Tensor(temperature, ms.float32).reshape(-1, 1)
+        self.ref_kinetics = 0.5 * self.degrees_of_freedom * self.boltzmann * self.ref_temp
+
+        # \tau_t
+        self.time_constant = Tensor(time_constant, ms.float32).reshape(-1, 1)
+        if self.time_constant.shape[0] != self.num_walker and self.time_constant.shape[0] != 1:
+            raise ValueError(
+                'The first dimension of time_constant must be equal to 1 or num_walker')
+
+    @property
+    def temperature(self):
+        """reference temperature."""
+        return self.ref_temp
+
+    @property
+    def kinetics(self):
+        """reference kinetics."""
+        return self.ref_kinetics
+
+    def set_degrees_of_freedom(self, dofs: int):
+        """
+        set degrees of freedom (DOFs).
+
+        Args:
+            dofs (int): Degrees of freedom.
+        """
+        self.degrees_of_freedom = dofs
+        self.ref_kinetics = 0.5 * self.degrees_of_freedom * self.boltzmann * self.ref_temp
+        return self
+
+    def velocity_scale(self, sim_kinetics: Tensor, ref_kinetics: Tensor, ratio: float = 1) -> Tensor:
+        r"""
+        calculate the velocity scale factor for temperature coupling.
+
+        Args:
+            sim_kinetics (Tensor): Tensor of simulation kinetics.
+            ref_kinetics (Tensor): Tensor of reference kinetics.
+            ratio (float): Coupling ratio used to compute the scale factor lambda\_.
+                Default: 1
+
+        Returns:
+            Tensor, the velocity scale factor.
+        """
+        sim_kinetics = func.keepdim_sum(sim_kinetics, -1)
+        lambda_ = 1. + ratio * (ref_kinetics / sim_kinetics - 1)
+        return F.sqrt(lambda_)
+
+    def construct(self,
+                  coordinate: Tensor,
+                  velocity: Tensor,
+                  force: Tensor,
+                  energy: Tensor,
+                  kinetics: Tensor,
+                  virial: Tensor = None,
+                  pbc_box: Tensor = None,
+                  step: int = 0,
+                  ):
+
+        r"""
+        Control the temperature of the simulation system.
+
+        Args:
+            coordinate (Tensor): Tensor of shape (B, A, D). Data type is float.
+            velocity (Tensor): Tensor of shape (B, A, D). Data type is float.
+            force (Tensor): Tensor of shape (B, A, D). Data type is float.
+            energy (Tensor): Tensor of shape (B, 1). Data type is float.
+            kinetics (Tensor): Tensor of shape (B, D). Data type is float.
+ virial (Tensor): Tensor of shape (B, D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + step (int): Simulation step. Default: 0 + + Returns: + coordinate (Tensor), Tensor of shape (B, A, D). Data type is float. + velocity (Tensor), Tensor of shape (B, A, D). Data type is float. + force (Tensor), Tensor of shape (B, A, D). Data type is float. + energy (Tensor), Tensor of shape (B, 1). Data type is float. + kinetics (Tensor), Tensor of shape (B, D). Data type is float. + virial (Tensor), Tensor of shape (B, D). Data type is float. + pbc_box (Tensor), Tensor of shape (B, D). Data type is float. + + Symbols: + B: Number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + """ + + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31df558ca232d981d6324125b9077ca22c6f2d72 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Core codes of MindSPONGE""" + +from .sponge import Sponge +from .simulation import SimulationCell, RunOneStepCell +from .analysis import AnalyseCell +from .wrapper import EnergySummation + +__all__ = ['Sponge', 'SimulationCell', 'RunOneStepCell', 'AnalyseCell', 'EnergySummation'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/analysis/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e1b00ba371dbdac8fadce8f451a64bc8430e34c --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/analysis/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Analysis""" + +from .analyse import AnalyseCell + +__all__ = ['AnalyseCell'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/analysis/analyse.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/analysis/analyse.py new file mode 100644 index 0000000000000000000000000000000000000000..5bbb990b5a27e9ae11c85842445e75866a6f1ba6 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/analysis/analyse.py @@ -0,0 +1,107 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Analyse Cell +""" + +import mindspore as ms +from mindspore import ops +from mindspore.nn import Cell +from mindspore.common import Tensor + +from ...system import Molecule +from ...potential import PotentialCell +from ...partition import NeighbourList + + +class AnalyseCell(Cell): + r""" + Core cell for analysis. + + Args: + system (Molecule): Simulation system. + potential (PotentialCell): Potential energy. + neighbour_list (NeighbourList): Neighbour list. Default: None + calc_energy (bool): Whether to calculate the energy. Default: False + calc_forces (bool): Whether to calculate the forces. Default: False + + Outputs: + - energy. + - forces. + - coordinates. + - pbc_box. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, + system: Molecule, + potential: PotentialCell, + neighbour_list: NeighbourList = None, + calc_energy: bool = False, + calc_forces: bool = False, + ): + + super().__init__(auto_prefix=False) + + self.system = system + self.potential = potential + self.pbc_box = self.system.pbc_box + + self.neighbour_list = neighbour_list + if neighbour_list is None: + self.neighbour_list = NeighbourList(system) + + self.calc_energy = calc_energy + self.calc_forces = calc_forces + + self.system_units = self.system.units + self.potential_units = self.potential.units + + self.units = self.system.units + + self.input_unit_scale = Tensor(self.units.convert_length_to( + self.potential.length_unit()), ms.float32) + self.output_unit_scale = Tensor(self.units.convert_energy_from( + self.potential.energy_unit()), ms.float32) + + self.grad = ops.GradOperation() + + def construct(self, coordinates=None, pbc_box=None): + """analyse the system.""" + if coordinates is None: + coordinates, pbc_box = self.system() + + coordinates *= self.input_unit_scale + if self.pbc_box is not None: + pbc_box *= self.input_unit_scale + + energy = None + if self.calc_energy: + energy = self.potential(coordinates, pbc_box) + + forces = None + if self.calc_forces: + forces = -self.grad(self.potential)(coordinates, pbc_box) + + return energy, forces, coordinates, pbc_box diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/simulation/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/simulation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..412ffb3fbd368efa3fd0c6ee7b0444b24f2fe7fa --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/simulation/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""simulation""" + +from .simulation import SimulationCell +from .run import RunOneStepCell + +__all__ = ['SimulationCell', 'RunOneStepCell'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/simulation/run.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/simulation/run.py new file mode 100644 index 0000000000000000000000000000000000000000..a45036a5f347e7ced4a360f851a68f7ec0473520 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/simulation/run.py @@ -0,0 +1,166 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. 
+# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +RunOneStepCell +""" + +from mindspore import ops +from mindspore.ops import functional as F +from mindspore import ms_function +from mindspore.nn import Cell + +from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, + _get_parallel_mode) +from mindspore.context import ParallelMode +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.nn.optim import Optimizer + +from .simulation import SimulationCell +from ...function.functions import get_integer +from ...optimizer import Updater + + +class RunOneStepCell(Cell): + r""" + Core cell to run one step simulation. + + Args: + network (SimulationCell): Network for simulation system. + optimizer (Optimizer): Optimizer for simulation. + steps (int): Steps for ms_function. Default: 1 + sens (float): The scaling number to be filled as the input of backpropagation. + Default: 1.0 + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, + network: SimulationCell, + optimizer: Optimizer, + steps: int = 1, + sens: float = 1.0, + ): + + super().__init__(auto_prefix=False) + + self.network = network + self.network.set_grad() + self.optimizer = optimizer + self.neighbour_list = self.network.neighbour_list + self.update_neighbour_list = self.network.update_neighbour_list + + self.coordinate = self.network.coordinate + self.pbc_box = self.network.pbc_box + + self.use_updater = isinstance(self.optimizer, Updater) + self.weights = self.optimizer.parameters + + self.grad = ops.GradOperation(get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = F.identity + self.parallel_mode = _get_parallel_mode() + self.reducer_flag = self.parallel_mode in ( + ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) + if self.reducer_flag: + self.mean = _get_gradients_mean() + self.degree = _get_device_num() + self.grad_reducer = DistributedGradReducer( + self.weights, self.mean, self.degree) + + self.steps = get_integer(steps) + + def set_pbc_grad(self, value: bool): + """ + set whether to calculate the gradient of PBC box. + + Args: + value (bool): Use to judge whether to calculate the gradient of PBC box. + """ + self.network.set_pbc_grad(value) + return self + + def set_steps(self, steps: int): + """ + set steps for ms_function. + + Args: + steps (int): steps of ms_function. + """ + self.steps = get_integer(steps) + return self + + @ms_function + def get_energy_and_force(self, *inputs): + """ + get energy and force of the system. + + Returns: + - energy (Tensor). + - force (Tensor). 
+ """ + energy = self.network(*inputs) + sens = F.fill(energy.dtype, energy.shape, self.sens) + force = - self.grad(self.network, self.coordinate)(*inputs, sens) + return energy, force + + # @ms_function + def run_one_step(self, *inputs): + """ + run one step simulation. + + Returns: + - energy (Tensor), the result of simulation cell. + - force (Tensor), the result of simulation cell. + """ + energy = self.network(*inputs) + + sens = F.fill(energy.dtype, energy.shape, self.sens) + grads = self.grad(self.network, self.weights)(*inputs, sens) + + force = -grads[0] + + if self.use_updater: + energy = F.depend(energy, self.optimizer(grads, energy)) + else: + energy = F.depend(energy, self.optimizer(grads)) + + return energy, force + + def construct(self, *inputs): + """ + run simulation + + Returns: + - energy (Tensor), the result of simulation cell. + - force (Tensor), the result of simulation cell. + """ + if self.steps == 1: + return self.run_one_step(*inputs) + + energy = None + force = None + for _ in range(self.steps): + energy, force = self.run_one_step(*inputs) + + return energy, force diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/simulation/simulation.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/simulation/simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..bb9e489b9eed8fa822afc8f1cf4b1a074627d201 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/simulation/simulation.py @@ -0,0 +1,264 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Simulation Cell +""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor +from mindspore import Parameter +from mindspore import context +from mindspore import ops, nn +from mindspore.ops import functional as F +from mindspore.nn import Cell, CellList + +from ...partition import NeighbourList +from ...system import Molecule +from ...potential import PotentialCell +from ...potential.bias import Bias +from ...function.functions import gather_vectors +from ...function.operations import GetVector +from ..wrapper import EnergyWrapper, get_energy_wrapper + + +class SimulationCell(Cell): + r""" + Core cell for simulation. + + Args: + system (Molecule): Simulation system. + potential (PotentialCell): Potential energy. + cutoff (float): Cutoff distance. Default: None + neighbour_list (NeighbourList): Neighbour list. Default: None + wrapper (EnergyWrapper): Network to wrap and process potential and bias. 
+            Default: 'sum'
+        bias (Bias): Bias potential. Default: None
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self,
+                 system: Molecule,
+                 potential: PotentialCell,
+                 cutoff: float = None,
+                 neighbour_list: NeighbourList = None,
+                 wrapper: EnergyWrapper = 'sum',
+                 bias: Bias = None,
+                 ):
+
+        super().__init__(auto_prefix=False)
+
+        self.system = system
+        self.potential = potential
+
+        self.bias_network = None
+        self.num_bias = 0
+        if bias is not None:
+            if isinstance(bias, list):
+                self.num_bias = len(bias)
+                self.bias_network = CellList(bias)
+            elif isinstance(bias, Cell):
+                self.num_bias = 1
+                self.bias_network = CellList([bias])
+            else:
+                raise TypeError('The "bias" must be Cell or list but got: '+str(type(bias)))
+
+        self.num_walker = self.system.num_walker
+        self.num_atoms = self.system.num_atoms
+
+        self.dim_potential = self.potential.output_dim
+        self.dim_bias = 0
+        if self.bias_network is not None:
+            self.dim_bias = len(self.bias_network)
+
+        self.energy_wrapper = get_energy_wrapper(
+            wrapper,
+            num_walker=self.num_walker,
+            dim_potential=self.dim_potential,
+            dim_bias=self.dim_bias,
+        )
+
+        self.exclude_index = self.potential.exclude_index
+        self.neighbour_list = neighbour_list
+        if neighbour_list is None:
+            self.neighbour_list = NeighbourList(
+                system, cutoff, exclude_index=self.exclude_index)
+        else:
+            self.neighbour_list.set_exclude_index(self.exclude_index)
+
+        self.neighbour_index = self.neighbour_list.neighbours
+        self.neighbour_mask = self.neighbour_list.neighbour_mask
+
+        self.no_mask = False
+        if context.get_context("mode") == context.PYNATIVE_MODE and self.neighbour_list.no_mask:
+            self.no_mask = True
+
+        self.num_neighbours = self.neighbour_list.num_neighbours
+
+        self.cutoff = self.neighbour_list.cutoff
+        if self.cutoff is not None:
+            self.potential.set_cutoff(self.cutoff)
+        self.nl_update_steps = self.neighbour_list.update_steps
+
+        self.coordinate = self.system.coordinate
+        self.pbc_box = self.system.pbc_box
+        self.atom_mass = self.system.atom_mass
+
+        use_pbc = self.pbc_box is not None
+
+        self.potential.set_pbc(use_pbc)
+
+        for p in self.potential.trainable_params():
+            p.requires_grad = False
+
+        self.units = self.system.units
+
+        self.potential_units = self.potential.units
+
+        self.input_unit_scale = Tensor(self.units.convert_length_to(
+            self.potential.length_unit), ms.float32)
+        self.output_unit_scale = Tensor(self.units.convert_energy_from(
+            self.potential.energy_unit), ms.float32)
+
+        self.get_vector = GetVector(use_pbc)
+
+        mask_fill = self.units.length(10, 'nm')
+        self.mask_fill = Tensor(mask_fill, ms.float32)
+
+        self.identity = ops.Identity()
+
+        self.bias = None
+        if self.bias_network is not None:
+            self.bias = Parameter(msnp.zeros((self.num_walker, self.num_bias), dtype=ms.float32),
+                                  name='bias_potential', requires_grad=False)
+
+        self.norm_last_dim = nn.Norm(axis=-1, keep_dims=False)
+
+    @property
+    def length_unit(self):
+        return self.units.length_unit
+
+    @property
+    def energy_unit(self):
+        return self.units.energy_unit
+
+    def set_pbc_grad(self, grad_box: bool):
+        """
+        set whether to calculate the gradient of PBC box.
+
+        Args:
+            grad_box (bool): Whether to calculate the gradient of PBC box.
+ """ + self.system.set_pbc_grad(grad_box) + return self + + def update_neighbour_list(self): + """update neighbour list.""" + coordinate, pbc_box = self.system() + return self.neighbour_list(coordinate, pbc_box) + + def get_neighbour_list(self): + """ + get neighbour list. + + Returns: + - neighbour_index (Tensor). + - neighbour_mask (Tensor). + """ + neighbour_index, neighbour_mask = self.neighbour_list.get_neighbour_list() + return neighbour_index, neighbour_mask + + def construct(self, *inputs): + """ + calculate the energy of system. + + Returns: + - energy (Tensor). + - force (Tensor). + """ + #pylint: disable=unused-argument + coordinate, pbc_box = self.system() + + coordinate *= self.input_unit_scale + if pbc_box is not None: + pbc_box *= self.input_unit_scale + + neighbour_index, neighbour_mask = self.get_neighbour_list() + + # (B,A,1,D) <- (B,A,D): + atoms = F.expand_dims(coordinate, -2) + # (B,A,N,D) <- (B,A,D): + neighbour_coord = gather_vectors(coordinate, neighbour_index) + neighbour_vector = self.get_vector(atoms, neighbour_coord, pbc_box) + + # Add a non-zero value to the neighbour_vector whose mask value is False + # to prevent them from becoming zero values after Norm operation, + # which could lead to auto-differentiation errors + if neighbour_mask is not None: + # (B,A,N): + mask_fill = msnp.where(neighbour_mask, 0, self.mask_fill) + # (B,A,N,D) = (B,A,N,D) + (B,A,N,1) + neighbour_vector += F.expand_dims(mask_fill, -1) + + # (B,A,N) = (B,A,N,D): + neighbour_distance = self.norm_last_dim(neighbour_vector) + + if self.cutoff is not None: + distance_mask = neighbour_distance < self.cutoff + if neighbour_mask is None: + neighbour_mask = distance_mask + else: + neighbour_mask = F.logical_and(distance_mask, neighbour_mask) + + potential = self.potential( + coordinate=coordinate, + neighbour_index=neighbour_index, + neighbour_mask=neighbour_mask, + neighbour_coord=neighbour_coord, + neighbour_distance=neighbour_distance, + pbc_box=pbc_box + ) * self.output_unit_scale + + bias = None + if self.bias_network is not None: + bias = () + for i in range(self.num_bias): + bias_ = self.bias_network[i]( + coordinate=coordinate, + neighbour_index=neighbour_index, + neighbour_mask=neighbour_mask, + neighbour_coord=neighbour_coord, + neighbour_distance=neighbour_distance, + pbc_box=pbc_box + ) + bias += (bias_,) + + bias = msnp.concatenate(bias, axis=-1) * self.output_unit_scale + F.depend(potential, F.assign(self.bias, bias)) + + return self.energy_wrapper(potential, bias) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/sponge.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/sponge.py new file mode 100644 index 0000000000000000000000000000000000000000..64fa91a64f3133d074e42f2c903eee3fed745e04 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/sponge.py @@ -0,0 +1,549 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/ ) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Core engine of MindSPONGE
+"""
+
+import os
+from typing import Union
+import time
+from collections.abc import Iterable
+
+from mindspore import nn
+from mindspore.ops import functional as F
+from mindspore.common import Tensor
+from mindspore.nn.optim import Optimizer
+
+from mindspore import context
+from mindspore.context import ParallelMode
+from mindspore.train.callback import Callback, RunContext, _InternalCallbackParam, _CallbackManager
+from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
+    _get_parameter_broadcast, _device_number_check
+from mindspore.parallel._ps_context import _is_role_pserver
+from mindspore.train.model import _StepSync, _transfer_tensor_to_tuple
+from mindspore.nn.metrics import get_metrics, Metric
+from mindspore.dataset.engine.datasets import Dataset
+
+from .simulation import SimulationCell
+from .simulation import RunOneStepCell
+from .analysis import AnalyseCell
+from ..potential import PotentialCell
+from ..optimizer import Updater, DynamicUpdater
+from ..system.molecule import Molecule
+
+
+class Sponge():
+    r"""
+    Core engine of MindSPONGE.
+
+    Args:
+        network (Union[Molecule, SimulationCell, RunOneStepCell]): Function or neural network for simulation system.
+        potential (Cell): Potential energy. Default: None
+        optimizer (Optimizer): Optimizer. Default: None
+        metrics (Metric): Metrics. Default: None
+        analyse_network (Cell): Analyse network. Default: None
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self,
+                 network: Union[Molecule, SimulationCell, RunOneStepCell],
+                 potential: PotentialCell = None,
+                 optimizer: Optimizer = None,
+                 metrics: Metric = None,
+                 analyse_network: AnalyseCell = None,
+                 ):
+
+        self._potential = potential
+        self._optimizer = optimizer
+        self._metrics = metrics
+        self._analyse_network = analyse_network
+
+        self._parallel_mode = _get_parallel_mode()
+        self._device_number = _get_device_num()
+        self._global_rank = _get_global_rank()
+        self._parameter_broadcast = _get_parameter_broadcast()
+        self._create_time = int(time.time() * 1e9)
+
+        if self._potential is None and self._optimizer is None:
+            self.sim_network: RunOneStepCell = network
+            self.sim_system: SimulationCell = self.sim_network.network
+            self._optimizer: Optimizer = self.sim_network.optimizer
+            self._system: Molecule = self.sim_system.system
+            self._potential: PotentialCell = self.sim_system.potential
+        else:
+            if self._optimizer is None:
+                raise ValueError(
+                    '"optimizer" cannot be "None" when potential is not None!')
+            if self._potential is None:
+                self.sim_system: SimulationCell = network
+                self.sim_network = RunOneStepCell(
+                    self.sim_system, self._optimizer)
+                self._system: Molecule = self.sim_system.system
+                self._potential: PotentialCell = self.sim_system.potential
+            else:
+                self._system: Molecule = network
+                self.sim_system = SimulationCell(self._system, self._potential)
+                self.sim_network = RunOneStepCell(self.sim_system, self._optimizer)
+
+        self._check_for_graph_cell()
+
+        self.use_updater = False
+        if isinstance(self._optimizer, Updater):
+            self.use_updater = True
+
+        if isinstance(self.sim_network, RunOneStepCell):
+            self.sim_network.set_pbc_grad(self.use_updater)
+
+        self.units = self._system.units
+
+        self.time_step = self._optimizer.learning_rate.asnumpy()
+
+        self.coordinate = self._system.coordinate
+        self.pbc_box = self._system.pbc_box
+        self.neighbour_list = self.sim_system.neighbour_list
+
+        self.cutoff = self.neighbour_list.cutoff
+        self.nl_update_steps = self.neighbour_list.update_steps
+
+        # Avoid passing a None neighbour mask through the graph in PyNative mode
+        self.one_neighbour_terms = False
+        if self.neighbour_list.no_mask and context.get_context("mode") == context.PYNATIVE_MODE:
+            self.one_neighbour_terms = True
+
+        self._metric_fns = None
+        if metrics is not None:
+            self._metric_fns = get_metrics(metrics)
+
+        self.analyse_network = analyse_network
+        if analyse_network is None and self._metric_fns is not None:
+            self.analyse_network = AnalyseCell(
+                self._system, self._potential, self.neighbour_list)
+
+        self.sim_step = 0
+        self.sim_time = 0.0
+
+    def change_optimizer(self, optimizer: Optimizer):
+        """
+        change optimizer.
+
+        Args:
+            optimizer (Optimizer): Optimizer to be used.
+        """
+        if self._optimizer is None:
+            raise ValueError('Cannot change the optimizer, because the initial optimizer is None '
+                             'or the network is not a RunOneStepCell type.')
+
+        self._optimizer = optimizer
+
+        if isinstance(self._optimizer, Updater):
+            self.use_updater = True
+        else:
+            self.use_updater = False
+
+        self.sim_network = RunOneStepCell(self.sim_system, self._optimizer)
+        self.sim_network.set_pbc_grad(self.use_updater)
+
+        self.time_step = self._optimizer.learning_rate.asnumpy()
+
+        return self
+
+    def change_potential(self, potential: PotentialCell):
+        """
+        change potential energy.
+
+        Args:
+            potential (PotentialCell): Potential energy to be used.
+        """
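A typical two-stage usage sketch for `change_optimizer` (`system`, `potential` and `updater` are placeholders for objects built elsewhere, not a fixed API example): minimise first, then swap in a dynamics updater without rebuilding the engine.

```python
from mindspore import nn

# system: Molecule, potential: PotentialCell, updater: e.g. a DynamicUpdater
md = Sponge(system, potential, nn.Adam(system.trainable_params(), learning_rate=1e-3))
md.run(500)                   # energy minimisation
md.change_optimizer(updater)  # switch to molecular dynamics
md.run(10000)                 # production run
```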
+ """ + if self._potential is None: + raise ValueError('Cannot change the potential, because the initial potential is None ' + 'or the network is not a SimulationCell type.') + if self._optimizer is None: + raise ValueError('Cannot change the potential, because the initial optimizer is None ' + 'or the network is not a RunOneStepCell type.') + + self._potential = potential + self.sim_system = SimulationCell(self._system, self._potential) + self.sim_network = RunOneStepCell(self.sim_system, self._optimizer) + self.sim_network.set_pbc_grad(self.use_updater) + + return self + + def run(self, + steps: int, + callbacks: Callback = None, + dataset: Dataset = None + ): + """ + Run simulation. + + Args: + steps (int): Simulation steps. + callbacks (Callback): Callback functions. Default: None + dataset (Dataset): Dataset used at simulation process. Default: None + + """ + if self.cutoff is None or steps < self.nl_update_steps: + epoch = 1 + cycle_steps = steps + rest_steps = 0 + else: + epoch = steps // self.nl_update_steps + cycle_steps = self.nl_update_steps + rest_steps = steps - epoch * cycle_steps + + cb_params = _InternalCallbackParam() + cb_params.sim_network = self.sim_network + cb_params.num_steps = steps + + cb_params.num_steps = steps + cb_params.time_step = self.time_step + cb_params.num_epoch = epoch + cb_params.cycle_steps = cycle_steps + cb_params.rest_steps = rest_steps + cb_params.cutoff = self.cutoff + + cb_params.mode = "simulation" + cb_params.sim_network = self.sim_network + cb_params.system = self._system + cb_params.potential = self._potential + cb_params.optimizer = self._optimizer + cb_params.parallel_mode = self._parallel_mode + cb_params.device_number = self._device_number + cb_params.simulation_dataset = dataset + cb_params.list_callback = self._transform_callbacks(callbacks) + if context.get_context("mode") == context.PYNATIVE_MODE: + cb_params.list_callback.insert(0, _StepSync()) + callbacks = cb_params.list_callback + + cb_params.coordinate = self.coordinate + cb_params.pbc_box = self.pbc_box + cb_params.volume = self._system.get_volume() + if self.use_updater: + self._optimizer.set_step(0) + cb_params.velocity = self._optimizer.velocity + kinetics = F.reduce_sum(self._optimizer.kinetics, -1) + cb_params.kinetics = kinetics + cb_params.temperature = self._optimizer.temperature + pressure = self._optimizer.pressure + if pressure is not None: + # (B) <- (B,D) + pressure = F.reduce_mean(pressure, -1) + cb_params.pressure = pressure + + cb_params.thermostat = None + cb_params.barostat = None + cb_params.constraint = None + if isinstance(self._optimizer, DynamicUpdater): + cb_params.thermostat = self._optimizer.thermostat + cb_params.barostat = self._optimizer.barostat + cb_params.constraint = self._optimizer.constraint + + # build callback list + with _CallbackManager(callbacks) as list_callback: + self._simulation_process( + epoch, cycle_steps, rest_steps, list_callback, cb_params) + + return self + + def energy(self): + """get energy of system""" + return self.sim_system() + + def energy_and_force(self): + """get energy and force""" + return self.sim_network.get_energy_and_force() + + def analyse(self, dataset=None, callbacks=None): + """ + Evaluation API where the iteration is controlled by python front-end. + + Configure to pynative mode or CPU, the evaluating process will be performed with dataset non-sink mode. + + Note: + If dataset_sink_mode is True, data will be sent to device. If the device is Ascend, features + of data will be transferred one by one. 
+
+    def analyse(self, dataset=None, callbacks=None):
+        """
+        Evaluation API where the iteration is controlled by python front-end.
+
+        When configured to PyNative mode or CPU, the evaluation process is performed in dataset non-sink mode.
+
+        Note:
+            If dataset_sink_mode is True, data will be sent to device. If the device is Ascend, features
+            of data will be transferred one by one. The limitation of data transmission per time is 256M.
+            When dataset_sink_mode is True, the step_end method of the Callback class will be executed when
+            the epoch_end method is called.
+
+        Args:
+            dataset (Dataset): Dataset to evaluate the model. Default: None.
+            callbacks (Optional[list(Callback)]): List of callback objects which should be executed
+                while analysing. Default: None.
+
+        Returns:
+            Dict, the key is the metric name defined by users and the value is the metrics value for
+            the model in the test mode.
+
+        Examples:
+            >>> # Assuming `md` is a Sponge engine built with a `metrics` dict, e.g.
+            >>> # md = Sponge(system, potential, optimizer, metrics={'rmsd': metric})
+            >>> result = md.analyse()
+        """
+
+        _device_number_check(self._parallel_mode, self._device_number)
+        if not self._metric_fns:
+            raise ValueError("The argument 'metrics' can not be None or empty, "
+                             "you should set the argument 'metrics' for Sponge.")
+
+        cb_params = _InternalCallbackParam()
+        cb_params.analyse_network = self.analyse_network
+        if dataset is not None:
+            cb_params.analysis_dataset = dataset
+            cb_params.batch_num = dataset.get_dataset_size()
+        cb_params.mode = "analyse"
+        cb_params.cur_step_num = 0
+
+        cb_params.list_callback = self._transform_callbacks(callbacks)
+
+        self._clear_metrics()
+
+        with _CallbackManager(callbacks) as list_callback:
+            return self._analyse_process(dataset, list_callback, cb_params)
+
+    def _check_for_graph_cell(self):
+        """Check for graph cell"""
+        if not isinstance(self._system, nn.GraphCell):
+            return
+
+        if self._potential is not None or self._optimizer is not None:
+            raise ValueError("For 'Sponge', 'potential' and 'optimizer' should be None when network is a GraphCell, "
+                             "but got 'potential': {}, 'optimizer': {}.".format(self._potential, self._optimizer))
+
+    @staticmethod
+    def _transform_callbacks(callbacks: Callback):
+        """Transform callback to a list."""
+        if callbacks is None:
+            return []
+
+        if isinstance(callbacks, Iterable):
+            return list(callbacks)
+
+        return [callbacks]
+
+    def _simulation_process(self,
+                            epoch: int,
+                            cycle_steps: int,
+                            rest_steps: int,
+                            list_callback: Callback = None,
+                            cb_params: _InternalCallbackParam = None
+                            ):
+        """
+        Simulation process. The data would be passed to network directly.
+
+        Args:
+            epoch (int): Number of neighbour-list update cycles to run.
+            cycle_steps (int): Number of simulation steps per cycle.
+            rest_steps (int): Remaining steps run after the last full cycle.
+            list_callback (Callback): Executor of callback list. Default: None.
+            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
+ """ + self._exec_preprocess(True) + + self.sim_step = 0 + self.sim_time = 0.0 + run_context = RunContext(cb_params) + list_callback.begin(run_context) + # used to stop training for early stop, such as stopAtTIme or stopATStep + should_stop = False + + for i in range(epoch): + cb_params.cur_epoch = i + if self.pbc_box is None: + coordinate = self._system() + pbc_box = None + else: + coordinate, pbc_box = self._system() + self.neighbour_list(coordinate, pbc_box) + should_stop = self._run_one_epoch( + cycle_steps, list_callback, cb_params, run_context) + should_stop = should_stop or run_context.get_stop_requested() + if should_stop: + break + + if rest_steps > 0: + if self.pbc_box is None: + coordinate = self._system() + pbc_box = None + else: + coordinate, pbc_box = self._system() + self.neighbour_list(coordinate, pbc_box) + self._run_one_epoch(rest_steps, list_callback, + cb_params, run_context) + + list_callback.end(run_context) + + def _run_one_epoch(self, + cycles: int, + list_callback: Callback, + cb_params: _InternalCallbackParam, + run_context: RunContext + ): + """run one epoch simulation""" + should_stop = False + list_callback.epoch_begin(run_context) + for _ in range(cycles): + + cb_params.cur_step = self.sim_step + cb_params.cur_time = self.sim_time + list_callback.step_begin(run_context) + + cb_params.volume = self._system.get_volume() + + energy, force = self.sim_network() + + cb_params.energy = energy + cb_params.force = force + + if self.use_updater: + cb_params.velocity = self._optimizer.velocity + # (B) <- (B,D) + kinetics = F.reduce_sum(self._optimizer.kinetics, -1) + cb_params.kinetics = kinetics + cb_params.temperature = self._optimizer.temperature + pressure = self._optimizer.pressure + if pressure is not None: + # (B) <- (B,D) + pressure = F.reduce_mean(pressure, -1) + cb_params.pressure = pressure + + self.sim_step += 1 + self.sim_time += self.time_step + + list_callback.step_end(run_context) + + #pylint: disable = protected-access + if _is_role_pserver(): + os._exit(0) + should_stop = should_stop or run_context.get_stop_requested() + if should_stop: + break + + # if param is cache enable, flush data from cache to host before epoch end + self._flush_from_cache(cb_params) + + list_callback.epoch_end(run_context) + return should_stop + + def _analyse_process(self, dataset=None, list_callback=None, cb_params=None): + """ + Evaluation. The data would be passed to network directly. + + Args: + valid_dataset (Dataset): Dataset to evaluate the model. + list_callback (Callback): Executor of callback list. Default: None. + cb_params (_InternalCallbackParam): Callback parameters. Default: None. + + Returns: + Dict, which returns the loss value and metrics values for the model in the test mode. 
+ """ + run_context = RunContext(cb_params) + list_callback.begin(run_context) + dataset_helper, _ = self._exec_preprocess(False) + list_callback.epoch_begin(run_context) + + if dataset is None: + cb_params.cur_step_num += 1 + list_callback.step_begin(run_context) + outputs = self.analyse_network() + cb_params.net_outputs = outputs + list_callback.step_end(run_context) + self._update_metrics(outputs) + else: + for next_element in dataset_helper: + cb_params.cur_step_num += 1 + list_callback.step_begin(run_context) + next_element = _transfer_tensor_to_tuple(next_element) + outputs = self.analyse_network(*next_element) + cb_params.net_outputs = outputs + list_callback.step_end(run_context) + self._update_metrics(outputs) + + list_callback.epoch_end(run_context) + dataset.reset() + metrics = self._get_metrics() + cb_params.metrics = metrics + list_callback.end(run_context) + return metrics + + def _clear_metrics(self): + """Clear metrics local values.""" + for metric in self._metric_fns.values(): + metric.clear() + + def _update_metrics(self, outputs): + """Update metrics local values.""" + if isinstance(outputs, Tensor): + outputs = (outputs,) + if not isinstance(outputs, tuple): + raise ValueError( + f"The argument 'outputs' should be tuple, but got {type(outputs)}.") + + for metric in self._metric_fns.values(): + metric.update(*outputs) + + def _get_metrics(self): + """Get metrics local values.""" + metrics = dict() + for key, value in self._metric_fns.items(): + metrics[key] = value.eval() + return metrics + + def _exec_preprocess(self, is_run): + """Initializes dataset.""" + if is_run: + network = self.sim_network + phase = 'simulation' + else: + network = self.analyse_network + phase = 'analyse' + + network.set_train(is_run) + network.phase = phase + + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + network.set_auto_parallel() + + return network + + def _flush_from_cache(self, cb_params): + """Flush cache data to host if tensor is cache enable.""" + params = cb_params.sim_network.get_parameters() + for param in params: + if param.cache_enable: + Tensor(param).flush_from_cache() + + @property + def create_time(self): + return self._create_time \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eca53f41f3405fcd7ba4510b744cc8705a0a2954 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Energy wrapper""" + +from .wrapper import EnergyWrapper, get_energy_wrapper +from .summation import EnergySummation + +__all__ = ['EnergyWrapper', 'get_energy_wrapper', 'EnergySummation'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/its.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/its.py new file mode 100644 index 0000000000000000000000000000000000000000..b4eaa7912b8a9aa1ecc4190ca82a698f2c7869cd --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/its.py @@ -0,0 +1,77 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Integrated tempering sampling (ITS)""" + +from mindspore import Tensor + +from .wrapper import EnergyWrapper +from .wrapper import _energy_wrapper_register + + +@_energy_wrapper_register('its') +class IntegratedTemperingSampling(EnergyWrapper): + r"""TODO: Integrated tempering sampling (ITS). + + Args: + + num_walker (int): Number of multiple walker (B). Default: 1 + + dim_potential (int): Dimension of potential energy (U). Default: 1 + + dim_bias (int): Dimension of bias potential (V). Default: 1 + + """ + + def __init__(self, + num_walker: int = 1, + dim_potential: int = 1, + dim_bias: int = 1, + ): + + super().__init__( + num_walker=num_walker, + dim_potential=dim_potential, + dim_bias=dim_bias, + ) + + def construct(self, potential: Tensor, bias: Tensor = None): + """merge the potential and bias. + + Args: + potential (Tensor): Tensor of shape (B, U). Data type is float. + Potential energy. + bias (Tensor): Tensor of shape (B, V). Data type is float. + Bias potential. Default: None + + Return: + energy (Tensor): Tensor of shape (B, 1). Data type is float. + Total energy. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + U: Dimension of potential energy. + V: Dimension of bias potential. + + """ + + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/remd.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/remd.py new file mode 100644 index 0000000000000000000000000000000000000000..05ec47bfb0a9c52a4f7afd363476417952e39f5b --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/remd.py @@ -0,0 +1,77 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. 
+# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Replica exchange molecular dynamics (REMD) """ + +from mindspore import Tensor + +from .wrapper import EnergyWrapper +from .wrapper import _energy_wrapper_register + + +@_energy_wrapper_register('remd') +class ReplicaExchange(EnergyWrapper): + r"""TODO: Replica exchange molecular dynamics (REMD). + + Args: + + num_walker (int): Number of multiple walker (B). Default: 1 + + dim_potential (int): Dimension of potential energy (U). Default: 1 + + dim_bias (int): Dimension of bias potential (V). Default: 1 + + """ + + def __init__(self, + num_walker: int = 1, + dim_potential: int = 1, + dim_bias: int = 1, + ): + + super().__init__( + num_walker=num_walker, + dim_potential=dim_potential, + dim_bias=dim_bias, + ) + + def construct(self, potential: Tensor, bias: Tensor = None): + """merge the potential and bias. + + Args: + potential (Tensor): Tensor of shape (B, U). Data type is float. + Potential energy. + bias (Tensor): Tensor of shape (B, V). Data type is float. + Bias potential. Default: None + + Return: + energy (Tensor): Tensor of shape (B, 1). Data type is float. + Total energy. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + U: Dimension of potential energy. + V: Dimension of bias potential. + + """ + + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/summation.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/summation.py new file mode 100644 index 0000000000000000000000000000000000000000..adb38b8bd0afdb0c5fbc53356b47e754aecf038f --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/summation.py @@ -0,0 +1,82 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Energy wrapper""" + +from mindspore import Tensor + +from .wrapper import EnergyWrapper +from .wrapper import _energy_wrapper_register + + +@_energy_wrapper_register('sum') +class EnergySummation(EnergyWrapper): + r""" + A network to sum the potential and bias directly. + + Args: + num_walker (int): Number of multiple walker (B). Default: 1 + dim_potential (int): Dimension of potential energy (U). Default: 1 + dim_bias (int): Dimension of bias potential (V). Default: 1 + + Outputs: + energy (Tensor), Tensor of shape (B, 1). Data type is float. Total energy. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, + num_walker: int = 1, + dim_potential: int = 1, + dim_bias: int = 1, + ): + + super().__init__( + num_walker=num_walker, + dim_potential=dim_potential, + dim_bias=dim_bias, + ) + + def construct(self, potential: Tensor, bias: Tensor = None): + """merge the potential and bias. + + Args: + potential (Tensor): Tensor of shape (B, U). Data type is float. + Potential energy. + bias (Tensor): Tensor of shape (B, V). Data type is float. + Bias potential. Default: None + + Return: + energy (Tensor), Tensor of shape (B, 1). Data type is float. Total energy. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + U: Dimension of potential energy. + V: Dimension of bias potential. + """ + + potential = self.sum_last_dim(potential) + if bias is None: + return potential + + bias = self.sum_last_dim(bias) + return potential + bias diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/wrapper.py b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..7fafae37c119104386f107f22f762439ee87b7e4 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/core/wrapper/wrapper.py @@ -0,0 +1,120 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
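A numeric check of what `EnergySummation.construct` computes (plain NumPy, illustration only): each term block is reduced over its last axis and the reductions are added, giving one (B, 1) total energy.

```python
import numpy as np

potential = np.array([[1.0, 2.0, 3.0]])  # B=1 walker, U=3 potential terms
bias = np.array([[0.5, -0.5]])           # V=2 bias terms

energy = potential.sum(-1, keepdims=True) + bias.sum(-1, keepdims=True)
print(energy)  # [[6.]]
```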
+# ============================================================================ +"""Energy wrapper""" + +from mindspore import Tensor +from mindspore import ops +from mindspore.nn import Cell + +from ...function import get_integer + +_ENERGY_WRAPPER_BY_KEY = dict() + + +def _energy_wrapper_register(*aliases): + """Return the alias register.""" + def alias_reg(cls): + name = cls.__name__ + name = name.lower() + if name not in _ENERGY_WRAPPER_BY_KEY: + _ENERGY_WRAPPER_BY_KEY[name] = cls + + for alias in aliases: + if alias not in _ENERGY_WRAPPER_BY_KEY: + _ENERGY_WRAPPER_BY_KEY[alias] = cls + + return cls + + return alias_reg + + +class EnergyWrapper(Cell): + r"""A network to process and merge the potential and bias during the simulation. + + Args: + + num_walker (int): Number of multiple walker (B). Default: 1 + + dim_potential (int): Dimension of potential energy (U). Default: 1 + + dim_bias (int): Dimension of bias potential (V). Default: 1 + + """ + def __init__(self, + num_walker: int = 1, + dim_potential: int = 1, + dim_bias: int = 1, + ): + + super().__init__(auto_prefix=False) + + self.num_walker = get_integer(num_walker) + self.dim_potential = get_integer(dim_potential) + self.dim_bias = get_integer(dim_bias) + + self.concat_last_dim = ops.Concat(-1) + self.sum_last_dim = ops.ReduceSum(keep_dims=True) + + def construct(self, potential: Tensor, bias: Tensor = None): + """merge the potential and bias. + + Args: + potential (Tensor): Tensor of shape (B, U). Data type is float. + Potential energy. + bias (Tensor): Tensor of shape (B, V). Data type is float. + Bias potential. Default: None + + Return: + energy (Tensor): Tensor of shape (B, 1). Data type is float. + Total energy. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + U: Dimension of potential energy. + V: Dimension of bias potential. + + """ + raise NotImplementedError + + +def get_energy_wrapper(wrapper: str, + num_walker: int, + dim_potential: int, + dim_bias: int, + ) -> EnergyWrapper: + """get energy wrapper by name""" + if wrapper is None or isinstance(wrapper, EnergyWrapper): + return wrapper + if isinstance(wrapper, str): + if wrapper.lower() == 'none': + return None + if wrapper.lower() in _ENERGY_WRAPPER_BY_KEY.keys(): + return _ENERGY_WRAPPER_BY_KEY.get(wrapper.lower())( + num_walker=num_walker, + dim_potential=dim_potential, + dim_bias=dim_bias, + ) + raise ValueError( + "The energy wrapper corresponding to '{}' was not found.".format(wrapper)) + raise TypeError( + "Unsupported energy wrapper type '{}'.".format(type(wrapper))) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e6dc39e0f435f75ee6f775b538dd87b625371d63 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/__init__.py @@ -0,0 +1,44 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
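The registry pattern used by `_energy_wrapper_register` above, reduced to a self-contained toy (illustration only): a class is stored under its lower-cased name plus any explicit aliases, and `get_energy_wrapper` resolves string keys through that dict.

```python
_REGISTRY = {}

def register(*aliases):
    def reg(cls):
        # register under the lower-cased class name, then each alias
        _REGISTRY.setdefault(cls.__name__.lower(), cls)
        for alias in aliases:
            _REGISTRY.setdefault(alias, cls)
        return cls
    return reg

@register('sum')
class EnergySummation:
    pass

assert _REGISTRY['sum'] is EnergySummation
assert _REGISTRY['energysummation'] is EnergySummation
```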
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Data""" + +from .elements import elements, element_dict, element_name, element_set, atomic_mass +from .hyperparam import str_to_tensor, tensor_to_str +from .hyperparam import get_class_parameters, get_hyper_parameter, get_hyper_string +from .hyperparam import set_class_parameters, set_hyper_parameter, set_class_into_hyper_param +from .hyperparam import load_checkpoint, load_hyperparam, load_hyper_param_into_class +from .template import get_template, get_template_index, get_molecule +from .parameters import ForceFieldParameters +from .forcefield import get_forcefield +from .data import read_yaml, write_yaml, update_dict +from .data import get_bonded_types, get_dihedral_types, get_improper_types +from .data_transform import atom37_to_frames, atom37_to_torsion_angles + +__all__ = ['elements', 'element_dict', 'element_name', 'element_set', 'atomic_mass', + 'str_to_tensor', 'tensor_to_str', 'get_class_parameters', 'get_hyper_parameter', + 'get_hyper_string', 'set_class_parameters', 'set_hyper_parameter', + 'set_class_into_hyper_param', 'load_checkpoint', 'load_hyperparam', + 'load_hyper_param_into_class', 'get_template', 'get_template_index', + 'get_molecule', 'ForceFieldParameters', 'get_forcefield', 'read_yaml', + 'write_yaml', 'update_dict', 'get_bonded_types', 'get_dihedral_types', + 'get_improper_types', "atom37_to_frames", "atom37_to_torsion_angles"] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/data.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/data.py new file mode 100644 index 0000000000000000000000000000000000000000..a560a483e36c1523efc8920b70ad91af7d93ce1e --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/data.py @@ -0,0 +1,182 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Base function for yaml +""" + +from itertools import permutations +import yaml +import numpy as np +from numpy import ndarray + + +def update_dict(origin_dict: dict, new_dict: dict) -> dict: + """ + update complex dict. + + Args: + origin_dict(dict): The original input dict need to be updated. 
+        new_dict(dict): New dict whose entries will be merged into the original dict.
+
+    Returns:
+        dict, the updated dict.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if new_dict is None:
+        return origin_dict
+    dictionary = origin_dict.copy()
+    for k, v in new_dict.items():
+        if k in dictionary.keys() and isinstance(dictionary.get(k), dict) and isinstance(v, dict):
+            dictionary[k] = update_dict(dictionary[k], v)
+        else:
+            dictionary[k] = v
+    return dictionary
+
+
+def write_yaml(filename: str, data: dict):
+    """
+    write yaml file.
+
+    Args:
+        filename(str): name of yaml file.
+        data(dict): A dict of data.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    with open(filename, 'w', encoding="utf-8") as file:
+        yaml.dump(data, file, sort_keys=False)
+
+
+def read_yaml(filename: str) -> dict:
+    """
+    read yaml file.
+
+    Args:
+        filename(str): the name of yaml file.
+
+    Returns:
+        data(dict), data in the yaml file.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    with open(filename, 'r', encoding="utf-8") as file:
+        data = yaml.safe_load(file.read())
+    return data
+
+
+def get_bonded_types(atom_types: ndarray, symbol: str = '-'):
+    """
+    get the types of bonded terms including bond, angle and dihedral.
+
+    Args:
+        atom_types(ndarray): types of atoms.
+        symbol(str): separator between atom type names. Default: '-'
+
+    Returns:
+        types(ndarray), types of bonded terms including bond, angle and dihedral.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    num_atoms = atom_types.shape[-1]
+
+    if num_atoms == 1:
+        return atom_types
+
+    types = atom_types[..., 0]
+    for i in range(1, num_atoms):
+        types = np.char.add(types, symbol)
+        types = np.char.add(types, atom_types[..., i])
+
+    return types
+
+
+def get_dihedral_types(atom_types: ndarray, symbol: str = '-'):
+    """
+    Construct multi-atom type names in forward and reversed atom order.
+
+    Args:
+        atom_types(ndarray): types of atoms.
+        symbol(str): separator between atom type names. Default: '-'
+
+    Returns:
+        - types, ndarray.
+        - inverse_types, ndarray.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    num_atoms = atom_types.shape[-1]
+
+    if num_atoms == 1:
+        return atom_types
+
+    types = atom_types[..., 0]
+    for i in range(1, num_atoms):
+        types = np.char.add(types, symbol)
+        types = np.char.add(types, atom_types[..., i])
+
+    inverse_types = atom_types[..., -1]
+    for i in range(1, num_atoms):
+        inverse_types = np.char.add(inverse_types, symbol)
+        inverse_types = np.char.add(inverse_types, atom_types[..., -1-i])
+
+    return types, inverse_types
+
+
+def get_improper_types(atom_types: ndarray, symbol: str = '-'):
+    """
+    Construct multi-atom type names for all permutations of the atoms.
+
+    Args:
+        atom_types(ndarray): types of atoms.
+        symbol(str): separator between atom type names. Default: '-'
+
+    Returns:
+        - permutation_types, tuple.
+        - orders, tuple.
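Worked behaviour of `update_dict` above (values traced by hand from the code; illustration only): nested dicts merge recursively, scalar values are overwritten.

```python
base = {'model': {'cutoff': 1.0, 'units': 'nm'}, 'seed': 1}
override = {'model': {'cutoff': 0.8}, 'seed': 2}

merged = update_dict(base, override)
print(merged)  # {'model': {'cutoff': 0.8, 'units': 'nm'}, 'seed': 2}
```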
+ + Supported Platforms: + ``Ascend`` ``GPU`` + """ + num_atoms = atom_types.shape[-1] + + if num_atoms == 1: + return atom_types + + permuation_types = () + orders = () + for combination in permutations(range(num_atoms)): + types = atom_types[..., combination[0]] + for i in range(1, num_atoms): + types = np.char.add(types, symbol) + types = np.char.add(types, atom_types[..., combination[i]]) + permuation_types += (types,) + orders += (combination,) + + return permuation_types, orders diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/data_transform.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/data_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..1eb95dec904a3811d238f7aa56eb17fb552bc61f --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/data_transform.py @@ -0,0 +1,1136 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""data transform MSA TEMPLATE""" +import numpy as np +from scipy.special import softmax +from ..common import geometry as geometry +from ..common.residue_constants import chi_angles_mask, chi_pi_periodic, restype_1to3, chi_angles_atoms, \ + atom_order, residue_atom_renaming_swaps, restype_3to1, MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, restype_order, \ + restypes, restype_name_to_atom14_names, atom_types, residue_atoms, STANDARD_ATOM_MASK, restypes_with_x_and_gap, \ + MSA_PAD_VALUES + +MS_MIN32 = -2147483648 +MS_MAX32 = 2147483647 + + +def one_hot(depth, indices): + """one hot compute""" + res = np.eye(depth)[indices.reshape(-1)] + return res.reshape(list(indices.shape) + [depth]) + + +def correct_msa_restypes(msa, deletion_matrix=None, is_evogen=False): + """Correct MSA restype to have the same order as residue_constants.""" + new_order_list = MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = np.array(new_order_list, dtype=msa.dtype) + msa = new_order[msa] + if is_evogen: + msa_input = np.concatenate((msa, deletion_matrix), axis=-1).astype(np.int32) + result = msa, msa_input + else: + result = msa + return result + + +def randomly_replace_msa_with_unknown(msa, aatype, replace_proportion): + """Replace a proportion of the MSA with 'X'.""" + msa_mask = np.random.uniform(size=msa.shape, low=0, high=1) < replace_proportion + x_idx = 20 + gap_idx = 21 + msa_mask = np.logical_and(msa_mask, msa != gap_idx) + msa = np.where(msa_mask, np.ones_like(msa) * x_idx, msa) + aatype_mask = np.random.uniform(size=aatype.shape, low=0, high=1) < replace_proportion + aatype = np.where(aatype_mask, np.ones_like(aatype) * x_idx, aatype) + return msa, aatype + + +def fix_templates_aatype(template_aatype): + """Fixes aatype encoding of templates.""" + # Map one-hot to indices. + template_aatype = np.argmax(template_aatype, axis=-1).astype(np.int32) + # Map hhsearch-aatype to our aatype. 
+ new_order_list = MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = np.array(new_order_list, np.int32) + template_aatype = new_order[template_aatype] + return template_aatype + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + """compute pseudo beta features from atom positions""" + is_gly = np.equal(aatype, restype_order['G']) + ca_idx = atom_order['CA'] + cb_idx = atom_order['CB'] + pseudo_beta = np.where( + np.tile(is_gly[..., None].astype("int32"), [1] * len(is_gly.shape) + [3]).astype("bool"), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :]) + if all_atom_masks is not None: + pseudo_beta_mask = np.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) + pseudo_beta_mask = pseudo_beta_mask.astype(np.float32) + return pseudo_beta, pseudo_beta_mask + return pseudo_beta + + +def make_atom14_masks(aatype): + """create atom 14 position features from aatype""" + rt_atom14_to_atom37 = [] + rt_atom37_to_atom14 = [] + rt_atom14_mask = [] + + for restype in restypes: + atom_names = restype_name_to_atom14_names.get(restype_1to3.get(restype)) + + rt_atom14_to_atom37.append([(atom_order[name] if name else 0) for name in atom_names]) + + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + rt_atom37_to_atom14.append([(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in atom_types]) + + rt_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + # Add dummy mapping for restype 'UNK' + rt_atom14_to_atom37.append([0] * 14) + rt_atom37_to_atom14.append([0] * 37) + rt_atom14_mask.append([0.] * 14) + + rt_atom14_to_atom37 = np.array(rt_atom14_to_atom37, np.int32) + rt_atom37_to_atom14 = np.array(rt_atom37_to_atom14, np.int32) + rt_atom14_mask = np.array(rt_atom14_mask, np.float32) + + ri_atom14_to_atom37 = rt_atom14_to_atom37[aatype] + ri_atom14_mask = rt_atom14_mask[aatype] + + atom14_atom_exists = ri_atom14_mask + ri_atom14_to_atom37 = ri_atom14_to_atom37 + + # create the gather indices for mapping back + ri_atom37_to_atom14 = rt_atom37_to_atom14[aatype] + ri_atom37_to_atom14 = ri_atom37_to_atom14 + + # create the corresponding mask + restype_atom37_mask = np.zeros([21, 37], np.float32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3.get(restype_letter) + atom_names = residue_atoms.get(restype_name) + for atom_name in atom_names: + atom_type = atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + atom37_atom_exists = restype_atom37_mask[aatype] + res = [atom14_atom_exists, ri_atom14_to_atom37, ri_atom37_to_atom14, atom37_atom_exists] + return res + + +def block_delete_msa_indices(msa, msa_fraction_per_block, randomize_num_blocks, num_blocks): + """Sample MSA by deleting contiguous blocks. + + Jumper et al. (2021) Suppl. Alg. 
1 "MSABlockDeletion" + + Arguments: + protein: batch dict containing the msa + config: ConfigDict with parameters + + Returns: + updated protein + """ + + num_seq = msa.shape[0] + block_num_seq = np.floor(num_seq * msa_fraction_per_block).astype(np.int32) + + if randomize_num_blocks: + nb = int(np.random.uniform(0, num_blocks + 1)) + else: + nb = num_blocks + del_block_starts = np.random.uniform(0, num_seq, nb).astype(np.int32) + del_blocks = del_block_starts[:, None] + np.array([_ for _ in range(block_num_seq)]).astype(np.int32) + del_blocks = np.clip(del_blocks, 0, num_seq - 1) + del_indices = np.unique(np.sort(np.reshape(del_blocks, (-1,)))) + + # Make sure we keep the original sequence + keep_indices = np.setdiff1d(np.array([_ for _ in range(1, num_seq)]), + del_indices) + keep_indices = np.concatenate([[0], keep_indices], axis=0) + keep_indices = [int(x) for x in keep_indices] + return keep_indices + + +def sample_msa(msa, max_seq): + """Sample MSA randomly, remaining sequences are stored as `extra_*`.""" + num_seq = msa.shape[0] + + shuffled = list(range(1, num_seq)) + np.random.shuffle(shuffled) + shuffled.insert(0, 0) + index_order = np.array(shuffled, np.int32) + num_sel = min(max_seq, num_seq) + + sel_seq = index_order[:num_sel] + not_sel_seq = index_order[num_sel:] + is_sel = num_seq - num_sel + return is_sel, not_sel_seq, sel_seq + + +def gumbel_noise(shape): + """Generate Gumbel Noise of given Shape.""" + epsilon = 1e-6 + uniform_noise = np.random.uniform(0, 1, shape) + gumbel = -np.log(-np.log(uniform_noise + epsilon) + epsilon) + return gumbel + + +def gumbel_argsort_sample_idx(logits): + """Samples with replacement from a distribution given by 'logits'.""" + z = gumbel_noise(logits.shape) + return np.argsort(logits + z, axis=-1)[..., ::-1] + + +def gumbel_permutation(msa_mask, msa_chains=None): + """gumbel permutation.""" + has_msa = np.sum(msa_mask, axis=-1) > 0 + # default logits is zero + logits = np.zeros_like(has_msa, dtype=np.float32) + logits[~has_msa] = -1e6 + # one sample only + assert len(logits.shape) == 1 + # skip first row + logits = logits[1:] + has_msa = has_msa[1:] + if logits.shape[0] == 0: + return np.array([0]) + if msa_chains is not None: + # skip first row + msa_chains = msa_chains[1:].reshape(-1) + msa_chains[~has_msa] = 0 + keys, _ = np.unique(msa_chains, return_counts=True) + num_has_msa = np.array(has_msa.sum()) + num_pair = np.array((msa_chains == 1).sum()) + num_unpair = num_has_msa - num_pair + num_chains = np.array((keys > 1).sum()) + logits[has_msa] = 1.0 / (num_has_msa + 1e-6) + logits[~has_msa] = 0 + for k in keys: + if k > 1: + cur_mask = msa_chains == k + cur_cnt = np.array(cur_mask.sum()) + if cur_cnt > 0: + logits[cur_mask] *= num_unpair / (num_chains * cur_cnt) + logits = np.log(logits + 1e-6) + shuffled = gumbel_argsort_sample_idx(logits) + 1 + return np.concatenate((np.array([0]), shuffled), axis=0) + + +def sample_msa_v2(msa, msa_chains, msa_mask, max_seq, biased_msa_by_chain=False): + """Sample MSA randomly in multimer, remaining sequences are stored as `extra_*`.""" + num_seq = msa.shape[0] + num_sel = min(max_seq, num_seq) + msa_chain = (msa_chains if biased_msa_by_chain else None) + index_order = gumbel_permutation(msa_mask, msa_chain) + num_sel = min(max_seq, num_seq) + sel_seq = index_order[:num_sel] + not_sel_seq = index_order[num_sel:] + is_sel = num_seq - num_sel + return is_sel, not_sel_seq, sel_seq + + +def shape_list(x): + """get the list of dimensions of an array""" + x = np.array(x) + if x.ndim is None: + return x.shape 
+ + static = x.shape + ret = [] + for _, dimension in enumerate(static): + ret.append(dimension) + return ret + + +def shaped_categorical(probability): + """get categorical shape""" + ds = shape_list(probability) + num_classes = ds[-1] + flat_probs = np.reshape(probability, (-1, num_classes)) + numbers = list(range(num_classes)) + res = [] + for flat_prob in flat_probs: + res.append(np.random.choice(numbers, p=flat_prob)) + return np.reshape(np.array(res, np.int32), ds[:-1]) + + +def make_masked_msa(msa, hhblits_profile, uniform_prob, profile_prob, same_prob, replace_fraction, residue_index=None, + msa_mask=None, is_evogen=False): + """create masked msa for BERT on raw MSA features""" + + random_aatype = np.array([0.05] * 20 + [0., 0.], dtype=np.float32) + + probability = uniform_prob * random_aatype + profile_prob * hhblits_profile + same_prob * one_hot(22, msa) + + pad_shapes = [[0, 0] for _ in range(len(probability.shape))] + pad_shapes[-1][1] = 1 + mask_prob = 1. - profile_prob - same_prob - uniform_prob + + probability = np.pad(probability, pad_shapes, constant_values=(mask_prob,)) + + masked_aatype = np.random.uniform(size=msa.shape, low=0, high=1) < replace_fraction + + bert_msa = shaped_categorical(probability) + bert_msa = np.where(masked_aatype, bert_msa, msa) + + bert_mask = masked_aatype.astype(np.int32) + true_msa = msa + msa = bert_msa + if is_evogen: + additional_input = np.concatenate((bert_msa[0][:, None], np.asarray(residue_index)[:, None], + msa_mask[0][:, None], + bert_mask[0][:, None]), + axis=-1).astype(np.int32) + make_masked_msa_result = bert_mask, true_msa, msa, additional_input + + else: + make_masked_msa_result = bert_mask, true_msa, msa + return make_masked_msa_result + + +def share_mask_by_entity(mask_position, entity_id, sym_id, num_sym): + "share mask by entity" + entity_id = entity_id + sym_id = sym_id + num_sym = num_sym + unique_entity_ids = np.unique(entity_id) + first_sym_mask = sym_id == 1 + for cur_entity_id in unique_entity_ids: + cur_entity_mask = entity_id == cur_entity_id + cur_num_sym = int(num_sym[cur_entity_mask][0]) + if cur_num_sym > 1: + cur_sym_mask = first_sym_mask & cur_entity_mask + cur_sym_bert_mask = mask_position[:, cur_sym_mask] + mask_position[:, cur_entity_mask] = cur_sym_bert_mask.repeat(cur_num_sym, 0).reshape( + cur_sym_bert_mask.shape[0], cur_sym_bert_mask.shape[1] * cur_num_sym) + return mask_position + + +def gumbel_max_sample(logits): + """Samples from a probability distribution given by 'logits'.""" + z = gumbel_noise(logits.shape) + return np.argmax(logits + z, axis=-1) + + +def make_masked_msa_v2(msa, hhblits_profile, msa_mask, entity_id, sym_id, num_sym, + uniform_prob, profile_prob, same_prob, + replace_fraction, share_mask=False, bert_mask=None): + """create masked msa for BERT on raw MSA features""" + + random_aatype = np.array([0.05] * 20 + [0., 0.], dtype=np.float32) + probability = uniform_prob * random_aatype + profile_prob * hhblits_profile + same_prob * one_hot(22, msa) + + pad_shapes = [[0, 0] for _ in range(len(probability.shape))] + pad_shapes[-1][1] = 1 + mask_prob = 1.0 - profile_prob - same_prob - uniform_prob + assert mask_prob >= 0.0 + probability = np.pad(probability, pad_shapes, constant_values=(mask_prob,)) + sh = msa.shape + mask_position = np.random.rand(*sh) < replace_fraction + mask_position &= np.array(msa_mask, dtype=bool) + if bert_mask is not None: + mask_position &= np.array(bert_mask, dtype=bool) + + #if share_mask: + # mask_position = share_mask_by_entity(mask_position, entity_id, sym_id, 
num_sym) + logits = np.log(probability + 1e-6) + bert_msa = gumbel_max_sample(logits) + bert_msa = np.where(mask_position, bert_msa, msa).astype(np.float32) + bert_msa *= msa_mask + + mask_position = np.array(mask_position, dtype=np.float32) + return mask_position, msa, bert_msa + + +def nearest_neighbor_clusters(msa_mask, msa, extra_msa_mask, extra_msa, gap_agreement_weight=0.): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + + # Determine how much weight we assign to each agreement. In theory, we could + # use a full blosum matrix here, but right now let's just down-weight gap + # agreement because it could be spurious. + # Never put weight on agreeing on BERT mask + weights = np.concatenate([np.ones(21), gap_agreement_weight * np.ones(1), np.zeros(1)], 0) + + # Make agreement score as weighted Hamming distance + sample_one_hot = msa_mask[:, :, None] * one_hot(23, msa) + num_seq, num_res, _ = sample_one_hot.shape + + array_extra_msa_mask = extra_msa_mask + if array_extra_msa_mask.any(): + extra_one_hot = extra_msa_mask[:, :, None] * one_hot(23, extra_msa) + extra_num_seq, _, _ = extra_one_hot.shape + + agreement = np.matmul( + np.reshape(extra_one_hot, [extra_num_seq, num_res * 23]), + np.reshape(sample_one_hot * weights, [num_seq, num_res * 23]).T) + # Assign each sequence in the extra sequences to the closest MSA sample + extra_cluster_assignment = np.argmax(agreement, axis=1) + else: + extra_cluster_assignment = np.array([]) + return extra_cluster_assignment + + +def nearest_neighbor_clusters_v2(msa, msa_mask, extra_msa, extra_msa_mask, + deletion_matrix, extra_deletion_matrix, gap_agreement_weight=0.0): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + + # Determine how much weight we assign to each agreement. In theory, we could + # use a full blosum matrix here, but right now let's just down-weight gap + # agreement because it could be spurious. + # Never put weight on agreeing on BERT mask. + + weights = np.concatenate([np.ones(21), gap_agreement_weight * np.ones(1), np.zeros(1)], 0) + msa_one_hot = one_hot(23, msa.astype(np.int32)) + extra_one_hot = one_hot(23, extra_msa) + + msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot + extra_one_hot_masked = extra_msa_mask[:, :, None] * extra_one_hot + + t1 = weights * msa_one_hot_masked + t1 = np.resize(t1, (t1.shape[0], t1.shape[1] * t1.shape[2])) + t2 = np.resize(extra_one_hot_masked, (extra_one_hot.shape[0], extra_one_hot.shape[1] * extra_one_hot.shape[2])) + agreement = t1 @ t2.T + cluster_assignment = softmax(1e3 * agreement, axis=0) + cluster_assignment *= np.einsum("mr, nr->mn", msa_mask, extra_msa_mask) + + cluster_count = np.sum(cluster_assignment, axis=-1) + cluster_count += 1.0 # We always include the sequence itself. + + msa_sum = np.einsum("nm, mrc->nrc", cluster_assignment, extra_one_hot_masked) + msa_sum += msa_one_hot_masked + + cluster_profile = msa_sum / cluster_count[:, None, None] + + del_sum = np.einsum( + "nm, mc->nc", cluster_assignment, extra_msa_mask * extra_deletion_matrix + ) + del_sum += deletion_matrix # Original sequence. 
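The agreement score used for nearest-neighbour clustering above, reduced to a tiny NumPy example (toy sizes; the real code uses 23 classes with the gap class down-weighted): each extra sequence is assigned to the sampled MSA row it matches at the most positions.

```python
import numpy as np

def one_hot(depth, idx):
    return np.eye(depth)[idx]

msa = np.array([[0, 1, 2], [2, 1, 0]])     # 2 sampled rows, 3 residues
extra = np.array([[0, 1, 1], [2, 2, 0]])   # 2 extra rows
weights = np.concatenate([np.ones(21), np.zeros(1), np.zeros(1)])  # gap weight 0

a = one_hot(23, extra).reshape(extra.shape[0], -1)
b = (one_hot(23, msa) * weights).reshape(msa.shape[0], -1)
assignment = np.argmax(a @ b.T, axis=1)
print(assignment)  # [0 1]: each extra row joins its closest sampled row
```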
+ cluster_deletion_mean = del_sum / cluster_count[:, None] + + return cluster_profile, cluster_deletion_mean + + +def summarize_clusters(msa, msa_mask, extra_cluster_assignment, extra_msa_mask, extra_msa, extra_deletion_matrix, + deletion_matrix): + """Produce profile and deletion_matrix_mean within each cluster.""" + num_seq = msa.shape[0] + + def csum(x): + result = [] + for i in range(num_seq): + result.append(np.sum(x[np.where(extra_cluster_assignment == i)], axis=0)) + return np.array(result) + + mask = extra_msa_mask + mask_counts = 1e-6 + msa_mask + csum(mask) # Include center + + msa_sum = csum(mask[:, :, None] * one_hot(23, extra_msa)) + msa_sum += one_hot(23, msa) # Original sequence + cluster_profile = msa_sum / mask_counts[:, :, None] + + del msa_sum + + del_sum = csum(mask * extra_deletion_matrix) + del_sum += deletion_matrix # Original sequence + cluster_deletion_mean = del_sum / mask_counts + del del_sum + + return cluster_profile, cluster_deletion_mean + + +def crop_extra_msa(extra_msa, max_extra_msa): + """MSA features are cropped so only `max_extra_msa` sequences are kept.""" + if extra_msa.any(): + num_seq = extra_msa.shape[0] + num_sel = np.minimum(max_extra_msa, num_seq) + shuffled = list(range(num_seq)) + np.random.shuffle(shuffled) + select_indices = shuffled[:num_sel] + return select_indices + return None + + +def make_msa_feat(between_segment_residues, aatype, msa, deletion_matrix, cluster_deletion_mean, cluster_profile, + extra_deletion_matrix): + """Create and concatenate MSA features.""" + # Whether there is a domain break. Always zero for chains, but keeping + # for compatibility with domain datasets. + has_break = np.clip(between_segment_residues.astype(np.float32), np.array(0), np.array(1)) + aatype_1hot = one_hot(21, aatype) + + target_feat = [np.expand_dims(has_break, axis=-1), aatype_1hot] + # target_feat = [aatype_1hot] + + msa_1hot = one_hot(23, msa) + has_deletion = np.clip(deletion_matrix, np.array(0), np.array(1)) + deletion_value = np.arctan(deletion_matrix / 3.) * (2. / np.pi) + + msa_feat = [msa_1hot, np.expand_dims(has_deletion, axis=-1), np.expand_dims(deletion_value, axis=-1)] + + if cluster_profile is not None: + deletion_mean_value = (np.arctan(cluster_deletion_mean / 3.) * (2. / np.pi)) + msa_feat.extend([cluster_profile, np.expand_dims(deletion_mean_value, axis=-1)]) + extra_has_deletion = None + extra_deletion_value = None + if extra_deletion_matrix is not None: + extra_has_deletion = np.clip(extra_deletion_matrix, np.array(0), np.array(1)) + extra_deletion_value = np.arctan(extra_deletion_matrix / 3.) * (2. 
/ np.pi) + + msa_feat = np.concatenate(msa_feat, axis=-1) + target_feat = np.concatenate(target_feat, axis=-1) + res = [extra_has_deletion, extra_deletion_value, msa_feat, target_feat] + return res + + +def make_msa_feat_v2(msa, deletion_matrix, cluster_deletion_mean, cluster_profile): + """Create and concatenate MSA features.""" + msa_1hot = one_hot(23, msa.astype(np.int32)) + has_deletion = np.clip(deletion_matrix, 0.0, 1.0)[..., None] + deletion_value = (np.arctan(deletion_matrix / 3.0) * (2.0 / np.pi))[..., None] + + deletion_mean_value = (np.arctan(cluster_deletion_mean / 3.0) * (2.0 / np.pi))[..., None] + + msa_feat = [ + msa_1hot, + has_deletion, + deletion_value, + cluster_profile, + deletion_mean_value, + ] + msa_feat = np.concatenate(msa_feat, axis=-1) + return msa_feat + + +def make_extra_msa_feat(extra_msa, extra_deletion_matrix, extra_msa_mask, num_extra_msa): + # 23 = 20 amino acids + 'X' for unknown + gap + bert mask + extra_msa = extra_msa[:num_extra_msa] + deletion_matrix = extra_deletion_matrix[:num_extra_msa] + has_deletion = np.clip(deletion_matrix, 0.0, 1.0) + deletion_value = np.arctan(deletion_matrix / 3.0) * (2.0 / np.pi) + extra_msa_mask = extra_msa_mask[:num_extra_msa] + return {"extra_msa": extra_msa, + "extra_msa_mask": extra_msa_mask, + "extra_msa_has_deletion": has_deletion, + "extra_msa_deletion_value": deletion_value} + + +def make_random_seed(size, seed_maker_t, low=MS_MIN32, high=MS_MAX32, random_recycle=False): + if random_recycle: + r = np.random.RandomState(seed_maker_t) + return r.uniform(size=size, low=low, high=high) + np.random.seed(seed_maker_t) + return np.random.uniform(size=size, low=low, high=high) + + +def random_crop_to_size(seq_length, template_mask, crop_size, max_templates, + subsample_templates=False, seed=0, random_recycle=False): + """Crop randomly to `crop_size`, or keep as is if shorter than that.""" + seq_length = seq_length + seq_length_int = int(seq_length) + if template_mask is not None: + num_templates = np.array(template_mask.shape[0], np.int32) + else: + num_templates = np.array(0, np.int32) + num_res_crop_size = np.minimum(seq_length, crop_size) + num_res_crop_size_int = int(num_res_crop_size) + + # Ensures that the cropping of residues and templates happens in the same way + # across ensembling iterations. + # Do not use for randomness that should vary in ensembling. + + if subsample_templates: + templates_crop_start = int(make_random_seed(size=(), seed_maker_t=seed, low=0, high=num_templates + 1, + random_recycle=random_recycle)) + else: + templates_crop_start = 0 + + num_templates_crop_size = np.minimum(num_templates - templates_crop_start, max_templates) + num_templates_crop_size_int = int(num_templates_crop_size) + + num_res_crop_start = int(make_random_seed(size=(), seed_maker_t=seed, low=0, + high=seq_length_int - num_res_crop_size_int + 1, + random_recycle=random_recycle)) + + templates_select_indices = np.argsort(make_random_seed(size=[num_templates], seed_maker_t=seed, + random_recycle=random_recycle)) + res = [num_res_crop_size, num_templates_crop_size_int, num_res_crop_start, num_res_crop_size_int, \ + templates_crop_start, templates_select_indices] + return res + + +def atom37_to_torsion_angles( + aatype: np.ndarray, + all_atom_pos: np.ndarray, + all_atom_mask: np.ndarray, + alt_torsions=False, + is_multimer=False, +): + r""" + This function calculates the seven torsion angles of each residue and encodes them in sine and cosine. 
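+    Encoding each angle as a (sine, cosine) pair keeps the representation continuous across
+    the :math:`\pm\pi` boundary; the pairs are normalised to unit length before being returned.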
+    The order of the seven torsion angles is [pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4].
+    Here, pre_omega is the dihedral between a given amino acid and the previous amino acid,
+    computed from the atoms `CA(-1)-C(-1)-N-CA`. phi is the dihedral `C(-1)-N-CA-C`,
+    and psi is the dihedral `N-CA-C-O`.
+
+    Args:
+        aatype (numpy.array): Amino acid type with shape :math:`(batch\_size, N_{res})`.
+        all_atom_pos (numpy.array): Atom37 representation of all atomic coordinates with
+            shape :math:`(batch\_size, N_{res}, 37, 3)`.
+        all_atom_mask (numpy.array): Atom37 representation of the mask on all atomic coordinates with
+            shape :math:`(batch\_size, N_{res}, 37)`.
+        alt_torsions (bool): Whether to replace the (sin, cos) values of masked torsion angles
+            with a fixed placeholder. Default: False.
+        is_multimer (bool): Whether the inputs come from the multimer pipeline. Default: False.
+
+    Returns:
+        Dict containing
+
+        - torsion_angles_sin_cos (numpy.array), with shape :math:`(batch\_size, N_{res}, 7, 2)` where
+          the final 2 dimensions denote sin and cos respectively.
+        - alt_torsion_angles_sin_cos (numpy.array), same as 'torsion_angles_sin_cos', but with the angle shifted
+          by pi for all chi angles affected by the naming ambiguities.
+        - torsion_angles_mask (numpy.array), mask for which of the 7 torsion angles are present,
+          with shape :math:`(batch\_size, N_{res}, 7)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.data.data_transform import atom37_to_torsion_angles
+        >>> n_res = 16
+        >>> bs = 1
+        >>> aatype = np.random.randint(0, 21, size=(bs, n_res)).astype(np.int32)
+        >>> all_atom_pos = np.random.randn(bs, n_res, 37, 3).astype(np.float32)
+        >>> all_atom_mask = np.random.randn(bs, n_res, 37).astype(np.float32)
+        >>> angle_label_feature = atom37_to_torsion_angles(aatype, all_atom_pos, all_atom_mask)
+        >>> print(angle_label_feature.keys())
+        dict_keys(['torsion_angles_sin_cos', 'alt_torsion_angles_sin_cos', 'torsion_angles_mask'])
+    """
+
+    true_aatype = np.minimum(aatype, 20)
+
+    # get the number of residues
+    num_batch, num_res = true_aatype.shape
+
+    paddings = np.zeros([num_batch, 1, 37, 3], np.float32)
+    padding_atom_pos = np.concatenate([paddings, all_atom_pos[:, :-1, :, :]], axis=1)
+    paddings = np.zeros([num_batch, 1, 37], np.float32)
+    padding_atom_mask = np.concatenate([paddings, all_atom_mask[:, :-1, :]], axis=1)
+
+    # compute padding atom position for omega, phi and psi
+    omega_atom_pos_padding = np.concatenate(
+        [padding_atom_pos[..., 1:3, :],
+         all_atom_pos[..., 0:2, :]
+         ], axis=-2)
+    phi_atom_pos_padding = np.concatenate(
+        [padding_atom_pos[..., 2:3, :],
+         all_atom_pos[..., 0:3, :]
+         ], axis=-2)
+    psi_atom_pos_padding = np.concatenate(
+        [all_atom_pos[..., 0:3, :],
+         all_atom_pos[..., 4:5, :]
+         ], axis=-2)
+
+    # compute padding atom position mask for omega, phi and psi
+    omega_mask_padding = (np.prod(padding_atom_mask[..., 1:3], axis=-1) *
+                          np.prod(all_atom_mask[..., 0:2], axis=-1))
+    phi_mask_padding = (padding_atom_mask[..., 2] * np.prod(all_atom_mask[..., 0:3], axis=-1))
+    psi_mask_padding = (np.prod(all_atom_mask[..., 0:3], axis=-1) * all_atom_mask[..., 4])
+
+    chi_atom_pos_indices = get_chi_atom_pos_indices()
+    if is_multimer:
+        atom_pos_indices = chi_atom_pos_indices[..., true_aatype, :, :]
+    else:
+        atom_pos_indices = np_gather_ops(chi_atom_pos_indices, true_aatype, 0, 0)
+
+    chi_atom_pos = np_gather_ops(all_atom_pos, atom_pos_indices, -2, 2, is_multimer)
+
+    angles_mask = list(chi_angles_mask)
+    angles_mask.append([0.0, 0.0, 0.0, 0.0])
+    angles_mask = np.array(angles_mask)
+
+    if is_multimer:
+        chis_mask = angles_mask[true_aatype, :]
+    else:
+        chis_mask = np_gather_ops(angles_mask, true_aatype,
0, 0) + + chi_angle_atoms_mask = np_gather_ops(all_atom_mask, atom_pos_indices, -1, 2, is_multimer) + + chi_angle_atoms_mask = np.prod(chi_angle_atoms_mask, axis=-1) + chis_mask = chis_mask * chi_angle_atoms_mask.astype(np.float32) + torsions_atom_pos_padding = np.concatenate( + [omega_atom_pos_padding[:, :, None, :, :], + phi_atom_pos_padding[:, :, None, :, :], + psi_atom_pos_padding[:, :, None, :, :], + chi_atom_pos + ], axis=2) + torsion_angles_mask_padding = np.concatenate( + [omega_mask_padding[:, :, None], + phi_mask_padding[:, :, None], + psi_mask_padding[:, :, None], + chis_mask + ], axis=2) + torsion_frames = geometry.rigids_from_3_points( + point_on_neg_x_axis=geometry.vecs_from_tensor(torsions_atom_pos_padding[:, :, :, 1, :]), + origin=geometry.vecs_from_tensor(torsions_atom_pos_padding[:, :, :, 2, :]), + point_on_xy_plane=geometry.vecs_from_tensor(torsions_atom_pos_padding[:, :, :, 0, :])) + inv_torsion_frames = geometry.invert_rigids(torsion_frames) + vecs = geometry.vecs_from_tensor(torsions_atom_pos_padding[:, :, :, 3, :]) + forth_atom_rel_pos = geometry.rigids_mul_vecs(inv_torsion_frames, vecs) + torsion_angles_sin_cos = np.stack( + [forth_atom_rel_pos[2], forth_atom_rel_pos[1]], axis=-1) + torsion_angles_sin_cos /= np.sqrt( + np.sum(np.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + + 1e-8) + + if is_multimer: + torsion_angles_sin_cos = torsion_angles_sin_cos * np.array( + [1., 1., -1., 1., 1., 1., 1.])[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)] + chi_is_ambiguous = np.array(chi_pi_periodic)[true_aatype, ...] + else: + torsion_angles_sin_cos *= np.array( + [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None] + + chi_is_ambiguous = np_gather_ops( + np.array(chi_pi_periodic), true_aatype) + + mirror_torsion_angles = np.concatenate( + [np.ones([num_batch, num_res, 3]), + 1.0 - 2.0 * chi_is_ambiguous], axis=-1) + alt_torsion_angles_sin_cos = (torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None]) + + if alt_torsions: + fix_torsions = np.stack([np.ones(torsion_angles_sin_cos.shape[:-1]), + np.zeros(torsion_angles_sin_cos.shape[:-1])], axis=-1) + torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask_padding[ + ..., None] + fix_torsions * (1 - torsion_angles_mask_padding[..., None]) + alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask_padding[ + ..., None] + fix_torsions * (1 - torsion_angles_mask_padding[..., None]) + + if is_multimer: + return { + 'torsion_angles_sin_cos': torsion_angles_sin_cos, + 'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos, + 'torsion_angles_mask': torsion_angles_mask_padding + } + return { + 'torsion_angles_sin_cos': torsion_angles_sin_cos[0], # (N, 7, 2) + 'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos[0], # (N, 7, 2) + 'torsion_angles_mask': torsion_angles_mask_padding[0] # (N, 7) + } + + +def atom37_to_frames( + aatype, + all_atom_positions, + all_atom_mask, + is_affine=False +): + r""" + Computes the torsion angle of up to 8 rigid groups for each residue, shape is :math:`[N_{res}, 8, 12]`, + where 8 is indicates that each residue can be divided into up to 8 rigid groups according to the dependence of + the atom on the torsion angle, there are 1 backbone frame and 7 side-chain frames. + For the meaning of 12 ,the first 9 elements are the 9 components of rotation matrix, the last + 3 elements are the 3 component of translation matrix. + + + Args: + aatype(numpy.array): Amino acid sequence, :math:`[N_{res}]` . 
+ all_atom_positions(numpy.array): The coordinates of all atoms, presented as atom37, :math:`[N_{res}, 37,3]`. + all_atom_mask(numpy.array): Mask of all atomic coordinates, :math:`[N_{res},37]`. + is_affine(bool): Whether to perform affine, the default value is False. + + Returns: + Dictionary, the specific content is as follows. + + - **rigidgroups_gt_frames** (numpy.array) - The torsion angle of the 8 rigid body groups for each residue, + :math:`[N_{res}, 8, 12]`. + - **rigidgroups_gt_exists** (numpy.array) - The mask of rigidgroups_gt_frames denoting whether the rigid body + group exists according to the experiment, :math:`[N_{res}, 8]`. + - **rigidgroups_group_exists** (numpy.array) - Mask denoting whether given group is in principle present + for given amino acid type, :math:`[N_{res}, 8]` . + - **rigidgroups_group_is_ambiguous** (numpy.array) - Indicates that the position is chiral symmetry, + :math:`[N_{res}, 8]` . + - **rigidgroups_alt_gt_frames** (numpy.array) - 8 Frames with alternative atom renaming + corresponding to 'all_atom_positions' represented as flat + 12 dimensional array :math:`[N_{res}, 8, 12]` . + - **backbone_affine_tensor** (numpy.array) - The translation and rotation of the local coordinates of each + amino acid relative to the global coordinates, :math:`[N_{res}, 7]` , for the last dimension, the first 4 + elements are the affine tensor which contains the rotation information, the last 3 elements are the + translations in space. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.data import atom37_to_frames + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> aatype = np.ones(193,dtype=np.int32) + >>> all_atom_positions = np.ones((193,37,3),dtype=np.float32) + >>> all_atom_mask = np.ones((193,37),dtype=np.int32) + >>> result = atom37_to_frames(aatype,all_atom_positions,all_atom_mask) + >>> for key in result.keys(): + >>> print(key,result[key].shape) + rigidgroups_gt_frames (193, 8, 12) + rigidgroups_gt_exists (193, 8) + rigidgroups_group_exists (193, 8) + rigidgroups_group_is_ambiguous (193, 8) + rigidgroups_alt_gt_frames (193, 8, 12) + """ + aatype_shape = aatype.shape + + flat_aatype = np.reshape(aatype, [-1]) + all_atom_positions = np.reshape(all_atom_positions, [-1, 37, 3]) + all_atom_mask = np.reshape(all_atom_mask, [-1, 37]) + + rigid_group_names_res = np.full([21, 8, 3], '', dtype=object) + + # group 0: backbone frame + rigid_group_names_res[:, 0, :] = ['C', 'CA', 'N'] + + # group 3: 'psi' + rigid_group_names_res[:, 3, :] = ['CA', 'C', 'O'] + + # group 4,5,6,7: 'chi1,2,3,4' + for restype, letter in enumerate(restypes): + restype_name = restype_1to3[letter] + for chi_idx in range(4): + if chi_angles_mask[restype][chi_idx]: + atom_names = chi_angles_atoms[restype_name][chi_idx] + rigid_group_names_res[restype, chi_idx + 4, :] = atom_names[1:] + + # create rigid group mask + rigid_group_mask_res = np.zeros([21, 8], dtype=np.float32) + rigid_group_mask_res[:, 0] = 1 + rigid_group_mask_res[:, 3] = 1 + rigid_group_mask_res[:20, 4:] = chi_angles_mask + + lookup_table = atom_order.copy() + lookup_table[''] = 0 + rigid_group_atom37_idx_restype = np.vectorize(lambda x: lookup_table[x])( + rigid_group_names_res) + + rigid_group_atom37_idx_residx = np_gather_ops( + rigid_group_atom37_idx_restype, flat_aatype) + + base_atom_pos = np_gather_ops( + all_atom_positions, + rigid_group_atom37_idx_residx, + batch_dims=1) + + gt_frames = geometry.rigids_from_3_points( + 
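+        # Added explanation: the frame of each rigid group is built from its three
+        # reference atoms gathered above -- the first lies on the negative x-axis,
+        # the second is the origin, and the third fixes the xy-plane
+        # (C, CA and N for the backbone group).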
point_on_neg_x_axis=geometry.vecs_from_tensor(base_atom_pos[:, :, 0, :]), + origin=geometry.vecs_from_tensor(base_atom_pos[:, :, 1, :]), + point_on_xy_plane=geometry.vecs_from_tensor(base_atom_pos[:, :, 2, :])) + + # get the group mask + group_masks = np_gather_ops(rigid_group_mask_res, flat_aatype) + + # get the atom mask + gt_atoms_exists = np_gather_ops( + all_atom_mask.astype(np.float32), + rigid_group_atom37_idx_residx, + batch_dims=1) + gt_masks = np.min(gt_atoms_exists, axis=-1) * group_masks + + rotations = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) + rotations[0, 0, 0] = -1 + rotations[0, 2, 2] = -1 + gt_frames = geometry.rigids_mul_rots(gt_frames, geometry.rots_from_tensor(rotations, use_numpy=True)) + + rigid_group_is_ambiguous_res = np.zeros([21, 8], dtype=np.float32) + rigid_group_rotations_res = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1]) + + for restype_name, _ in residue_atom_renaming_swaps.items(): + restype = restype_order[restype_3to1[restype_name]] + chi_idx = int(sum(chi_angles_mask[restype]) - 1) + rigid_group_is_ambiguous_res[restype, chi_idx + 4] = 1 + rigid_group_rotations_res[restype, chi_idx + 4, 1, 1] = -1 + rigid_group_rotations_res[restype, chi_idx + 4, 2, 2] = -1 + + # Gather the ambiguity information for each residue. + rigid_group_is_ambiguous_res_index = np_gather_ops( + rigid_group_is_ambiguous_res, flat_aatype) + rigid_group_ambiguity_rotation_res_index = np_gather_ops( + rigid_group_rotations_res, flat_aatype) + + # Create the alternative ground truth frames. + alt_gt_frames = geometry.rigids_mul_rots( + gt_frames, geometry.rots_from_tensor(rigid_group_ambiguity_rotation_res_index, use_numpy=True)) + + gt_frames_flat12 = np.stack(list(gt_frames[0]) + list(gt_frames[1]), axis=-1) + alt_gt_frames_flat12 = np.stack(list(alt_gt_frames[0]) + list(alt_gt_frames[1]), axis=-1) + # reshape back to original residue layout + gt_frames_flat12 = np.reshape(gt_frames_flat12, aatype_shape + (8, 12)) + gt_masks = np.reshape(gt_masks, aatype_shape + (8,)) + group_masks = np.reshape(group_masks, aatype_shape + (8,)) + gt_frames_flat12 = np.reshape(gt_frames_flat12, aatype_shape + (8, 12)) + rigid_group_is_ambiguous_res_index = np.reshape(rigid_group_is_ambiguous_res_index, aatype_shape + (8,)) + alt_gt_frames_flat12 = np.reshape(alt_gt_frames_flat12, + aatype_shape + (8, 12,)) + if not is_affine: + return { + 'rigidgroups_gt_frames': gt_frames_flat12, # shape (..., 8, 12) + 'rigidgroups_gt_exists': gt_masks, # shape (..., 8) + 'rigidgroups_group_exists': group_masks, # shape (..., 8) + 'rigidgroups_group_is_ambiguous': + rigid_group_is_ambiguous_res_index, # shape (..., 8) + 'rigidgroups_alt_gt_frames': alt_gt_frames_flat12, # shape (..., 8, 12) + } + + rotation = [[gt_frames[0][0], gt_frames[0][1], gt_frames[0][2]], + [gt_frames[0][3], gt_frames[0][4], gt_frames[0][5]], + [gt_frames[0][6], gt_frames[0][7], gt_frames[0][8]]] + translation = [gt_frames[1][0], gt_frames[1][1], gt_frames[1][2]] + backbone_affine_tensor = to_tensor(rotation, translation)[:, 0, :] + return { + 'rigidgroups_gt_frames': gt_frames_flat12, # shape (..., 8, 12) + 'rigidgroups_gt_exists': gt_masks, # shape (..., 8) + 'rigidgroups_group_exists': group_masks, # shape (..., 8) + 'rigidgroups_group_is_ambiguous': rigid_group_is_ambiguous_res_index, # shape (..., 8) + 'rigidgroups_alt_gt_frames': alt_gt_frames_flat12, # shape (..., 8, 12) + 'backbone_affine_tensor': backbone_affine_tensor, # shape (..., 7) + } + + +def get_chi_atom_pos_indices(): + """get the atom indices for computing chi 
angles for all residue types""" + chi_atom_pos_indices = [] + for residue_name in restypes: + residue_name = restype_1to3[residue_name] + residue_chi_angles = chi_angles_atoms[residue_name] + atom_pos_indices = [] + for chi_angle in residue_chi_angles: + atom_pos_indices.append([atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_pos_indices)): + atom_pos_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_pos_indices.append(atom_pos_indices) + + chi_atom_pos_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return np.array(chi_atom_pos_indices) + + +def gather(params, indices, axis=0): + """gather operation""" + func = lambda p, i: np.take(p, i, axis=axis) + return func(params, indices) + + +def np_gather_ops(params, indices, axis=0, batch_dims=0, is_multimer=False): + """np gather operation""" + if is_multimer: + assert axis < 0 or axis - batch_dims >= 0 + ranges = [] + for i, s in enumerate(params.shape[:batch_dims]): + r = np.arange(s) + r = np.resize(r, (1,) * i + r.shape + (1,) * (len(indices.shape) - i - 1)) + ranges.append(r) + remaining_dims = [slice(None) for _ in range(len(params.shape) - batch_dims)] + remaining_dims[axis - batch_dims if axis >= 0 else axis] = indices + ranges.extend(remaining_dims) + return params[tuple(ranges)] + + if batch_dims == 0: + return gather(params, indices) + result = [] + if batch_dims == 1: + for p, i in zip(params, indices): + axis = axis - batch_dims if axis - batch_dims > 0 else 0 + r = gather(p, i, axis=axis) + result.append(r) + return np.stack(result) + for p, i in zip(params[0], indices[0]): + r = gather(p, i, axis=axis) + result.append(r) + res = np.stack(result) + return res.reshape((1,) + res.shape) + + +def rot_to_quat(rot, unstack_inputs=False): + """transfer the rotation matrix to quaternion matrix""" + if unstack_inputs: + rot = [np.moveaxis(x, -1, 0) for x in np.moveaxis(rot, -2, 0)] + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + + k = [[xx + yy + zz, zy - yz, xz - zx, yx - xy], + [zy - yz, xx - yy - zz, xy + yx, xz + zx], + [xz - zx, xy + yx, yy - xx - zz, yz + zy], + [yx - xy, xz + zx, yz + zy, zz - xx - yy]] + + k = (1. / 3.) 
* np.stack([np.stack(x, axis=-1) for x in k], + axis=-2) + + # compute eigenvalues + _, qs = np.linalg.eigh(k) + return qs[..., -1] + + +def to_tensor(rotation, translation): + """get affine based on rotation and translation""" + quaternion = rot_to_quat(rotation) + return np.concatenate( + [quaternion] + + [np.expand_dims(x, axis=-1) for x in translation], + axis=-1) + + +def convert_monomer_features(chain_id, aatype, template_aatype): + """Reshapes and modifies monomer features for multimer models.""" + + auth_chain_id = np.asarray(chain_id, dtype=np.object_) + new_order_list = MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + monomer_aatype = np.argmax(aatype, axis=-1).astype(np.int32) + monomer_template_aatype = np.argmax(template_aatype, axis=-1).astype(np.int32) + monomer_template_aatype = np.take(new_order_list, monomer_template_aatype.astype(np.int32), axis=0) + + return auth_chain_id, monomer_aatype, monomer_template_aatype + + +def convert_unnecessary_leading_dim_feats(sequence, domain_name, num_alignments, seq_length): + """get first dimension data of unnecessary features.""" + + monomer_sequence = np.asarray(sequence[0], dtype=sequence.dtype) + monomer_domain_name = np.asarray(domain_name[0], dtype=domain_name.dtype) + monomer_num_alignments = np.asarray(num_alignments[0], dtype=num_alignments.dtype) + monomer_seq_length = np.asarray(seq_length[0], dtype=seq_length.dtype) + + converted_feature = (monomer_sequence, monomer_domain_name, monomer_num_alignments, monomer_seq_length) + return converted_feature + + +def process_unmerged_features(deletion_matrix_int, deletion_matrix_int_all_seq, aatype, entity_id, num_chains): + """Postprocessing stage for per-chain features before merging.""" + # Convert deletion matrices to float. + deletion_matrix = np.asarray(deletion_matrix_int, dtype=np.float32) + deletion_matrix_all_seq = np.asarray(deletion_matrix_int_all_seq, dtype=np.float32) + + all_atom_mask = STANDARD_ATOM_MASK[aatype] + all_atom_mask = all_atom_mask + all_atom_positions = np.zeros(list(all_atom_mask.shape) + [3]) + deletion_mean = np.mean(deletion_matrix, axis=0) + + # Add assembly_num_chains. + assembly_num_chains = np.asarray(num_chains) + entity_mask = (entity_id != 0).astype(np.int32) + post_feature = (deletion_matrix, deletion_matrix_all_seq, deletion_mean, all_atom_mask, all_atom_positions, + assembly_num_chains, entity_mask) + + return post_feature + + +def get_crop_size(num_alignments_all_seq, msa_all_seq, msa_crop_size, msa_size): + """get maximum msa crop size + + Args: + num_alignments_all_seq: num_alignments for all sequence, which record the total number of msa + msa_all_seq: un-paired sequences for all msa. + msa_crop_size: The total number of sequences to crop from the MSA. + msa_size: number of msa + + Returns: + msa_crop_size: msa sized to be cropped + msa_crop_size_all_seq: msa_crop_size for features with "_all_seq" + + """ + + msa_size_all_seq = num_alignments_all_seq + msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2) + + # We reduce the number of un-paired sequences, by the number of times a + # sequence from this chain's MSA is included in the paired MSA. This keeps + # the MSA size for each chain roughly constant. 
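+    # Illustrative (hypothetical) numbers: with msa_crop_size = 508 and
+    # num_alignments_all_seq = 2000, the paired budget is min(2000, 508 // 2) = 254;
+    # if 200 of those paired rows are non-gapped for this chain, the unpaired
+    # budget becomes max(508 - 200, 0) = 308 sequences.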
+ msa_all_seq = msa_all_seq[:msa_crop_size_all_seq, :] + num_non_gapped_pairs = np.sum(np.any(msa_all_seq != restypes_with_x_and_gap.index('-'), axis=1)) + num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, msa_crop_size_all_seq) + + # Restrict the unpaired crop size so that paired+unpaired sequences do not + # exceed msa_seqs_per_chain for each chain. + max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0) + msa_crop_size = np.minimum(msa_size, max_msa_crop_size) + return msa_crop_size, msa_crop_size_all_seq + + +def make_seq_mask(entity_id): + """seq mask info, True for entity_id > 0, False for entity_id <= 0.""" + + seq_mask = (entity_id > 0).astype(np.float32) + return seq_mask + + +def make_msa_mask(msa, entity_id): + """Mask features are all ones, but will later be zero-padded.""" + + msa_mask = np.ones_like(msa, dtype=np.float32) + + seq_mask = (entity_id > 0).astype(np.float32) + msa_mask *= seq_mask[None] + + return msa_mask + + +def add_padding(feature_name, feature): + """get padding data with specified shapes of feature""" + + num_res = feature.shape[1] + padding = MSA_PAD_VALUES.get(feature_name) * np.ones([1, num_res], feature.dtype) + return padding + + +def generate_random_sample(cfg, model_config): + '''generate_random_sample''' + np.random.seed(0) + num_noise = model_config.model.latent.num_noise + latent_dim = model_config.model.latent.latent_dim + + context_true_prob = np.absolute(model_config.train.context_true_prob) + keep_prob = np.absolute(model_config.train.keep_prob) + + available_msa = int(model_config.train.available_msa_fraction * model_config.train.max_msa_clusters) + available_msa = min(available_msa, model_config.train.max_msa_clusters) + + evogen_random_data = np.random.normal( + size=(num_noise, model_config.train.max_msa_clusters, cfg.eval.crop_size, latent_dim)).astype(np.float32) + + # (Nseq,): + context_mask = np.zeros((model_config.train.max_msa_clusters,), np.int32) + z1 = np.random.random(model_config.train.max_msa_clusters) + context_mask = np.asarray([1 if x < context_true_prob else 0 for x in z1], np.int32) + context_mask[available_msa:] *= 0 + + # (Nseq,): + target_mask = np.zeros((model_config.train.max_msa_clusters,), np.int32) + z2 = np.random.random(model_config.train.max_msa_clusters) + target_mask = np.asarray([1 if x < keep_prob else 0 for x in z2], np.int32) + + context_mask[0] = 1 + target_mask[0] = 1 + + evogen_context_mask = np.stack((context_mask, target_mask), -1) + return evogen_random_data, evogen_context_mask + + +def to_tensor_4x4(feature): + rots = feature[..., :9] + trans = feature[..., 9:] + arrays = np.zeros(feature.shape[:-1] + (4, 4)) + rots = np.reshape(rots, rots.shape[:-1] + (3, 3)) + arrays[..., :3, :3] = rots + arrays[..., :3, 3] = trans + arrays[..., 3, 3] = 1 + return arrays diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/elements.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/elements.py new file mode 100644 index 0000000000000000000000000000000000000000..da81dc9b2cdc20b306ef6653e5191408eb48bbf8 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/elements.py @@ -0,0 +1,519 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. 
+# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Information of chemical elements +""" +#pylint: disable=bad-whitespace + +import numpy as np + +elements = np.array([ + '', + 'H', + 'He', + 'Li', + 'Be', + 'B', + 'C', + 'N', + 'O', + 'F', + 'Ne', + 'Na', + 'Mg', + 'Al', + 'Si', + 'P', + 'S', + 'Cl', + 'Ar', + 'K', + 'Ca', + 'Sc', + 'Ti', + 'V', + 'Cr', + 'Mn', + 'Fe', + 'Co', + 'Ni', + 'Cu', + 'Zn', + 'Ga', + 'Ge', + 'As', + 'Se', + 'Br', + 'Kr', + 'Rb', + 'Sr', + 'Y', + 'Zr', + 'Nb', + 'Mo', + 'Tc', + 'Ru', + 'Rh', + 'Pd', + 'Ag', + 'Cd', + 'In', + 'Sn', + 'Sb', + 'Te', + 'I', + 'Xe', + 'Cs', + 'Ba', + 'La', + 'Ce', + 'Pr', + 'Nd', + 'Pm', + 'Sm', + 'Eu', + 'Gd', + 'Tb', + 'Dy', + 'Ho', + 'Er', + 'Tm', + 'Yb', + 'Lu', + 'Hf', + 'Ta', + 'W', + 'Re', + 'Os', + 'Ir', + 'Pt', + 'Au', + 'Hg', + 'Tl', + 'Pb', + 'Bi', + 'Po', + 'At', + 'Rn', + 'Fr', + 'Ra', + 'Ac', + 'Th', + 'Pa', + 'U', + 'Np', + 'Pu', + 'Am', + 'Cm', + 'Bk', + 'Cf', + 'Es', + 'Fm', + 'Md', + 'No', + 'Lr', + 'Rf', + 'Db', + 'Sg', + 'Bh', + 'Hs', + 'Mt', + 'Ds', + 'Rg', + 'Cn', + 'Nh', + 'Fl', + 'Mc', + 'Lv', + 'Ts', + 'Og', +]) + +element_set = set(elements) + +element_dict = { + 'X': 0, + '': 0, + 'H': 1, + 'He': 2, + 'Li': 3, + 'Be': 4, + 'B': 5, + 'C': 6, + 'N': 7, + 'O': 8, + 'F': 9, + 'Ne': 10, + 'Na': 11, + 'Mg': 12, + 'Al': 13, + 'Si': 14, + 'P': 15, + 'S': 16, + 'Cl': 17, + 'Ar': 18, + 'K': 19, + 'Ca': 20, + 'Sc': 21, + 'Ti': 22, + 'V': 23, + 'Cr': 24, + 'Mn': 25, + 'Fe': 26, + 'Co': 27, + 'Ni': 28, + 'Cu': 29, + 'Zn': 30, + 'Ga': 31, + 'Ge': 32, + 'As': 33, + 'Se': 34, + 'Br': 35, + 'Kr': 36, + 'Rb': 37, + 'Sr': 38, + 'Y': 39, + 'Zr': 40, + 'Nb': 41, + 'Mo': 42, + 'Tc': 43, + 'Ru': 44, + 'Rh': 45, + 'Pd': 46, + 'Ag': 47, + 'Cd': 48, + 'In': 49, + 'Sn': 50, + 'Sb': 51, + 'Te': 52, + 'I': 53, + 'Xe': 54, + 'Cs': 55, + 'Ba': 56, + 'La': 57, + 'Ce': 58, + 'Pr': 59, + 'Nd': 60, + 'Pm': 61, + 'Sm': 62, + 'Eu': 63, + 'Gd': 64, + 'Tb': 65, + 'Dy': 66, + 'Ho': 67, + 'Er': 68, + 'Tm': 69, + 'Yb': 70, + 'Lu': 71, + 'Hf': 72, + 'Ta': 73, + 'W': 74, + 'Re': 75, + 'Os': 76, + 'Ir': 77, + 'Pt': 78, + 'Au': 79, + 'Hg': 80, + 'Tl': 81, + 'Pb': 82, + 'Bi': 83, + 'Po': 84, + 'At': 85, + 'Rn': 86, + 'Fr': 87, + 'Ra': 88, + 'Ac': 89, + 'Th': 90, + 'Pa': 91, + 'U': 92, + 'Np': 93, + 'Pu': 94, + 'Am': 95, + 'Cm': 96, + 'Bk': 97, + 'Cf': 98, + 'Es': 99, + 'Fm': 100, + 'Md': 101, + 'No': 102, + 'Lr': 103, + 'Rf': 104, + 'Db': 105, + 'Sg': 106, + 'Bh': 107, + 'Hs': 108, + 'Mt': 109, + 'Ds': 110, + 'Rg': 111, + 'Cn': 112, + 'Nh': 113, + 'Fl': 114, + 'Mc': 115, + 'Lv': 116, + 'Ts': 117, + 'Og': 118, +} + +element_name = np.array([ + 'None', + 'Hydrogen', + 'Helium', + 'Lithium', + 'Beryllium', + 'Boron', + 'Carbon', + 'Nitrogen', + 'Oxygen', + 'Fluorine', + 'Neon', + 'Sodium', + 'Magnesium', + 'Aluminium', + 'Silicon', + 'Phosphorus', + 'Sulfur', + 'Chlorine', 
+ 'Argon', + 'Potassium', + 'Calcium', + 'Scandium', + 'Titanium', + 'Vanadium', + 'Chromium', + 'Manganese', + 'Iron', + 'Cobalt', + 'Nickel', + 'Copper', + 'Zinc', + 'Gallium', + 'Germanium', + 'Arsenic', + 'Selenium', + 'Bromine', + 'Krypton', + 'Rubidium', + 'Strontium', + 'Yttrium', + 'Zirconium', + 'Niobium', + 'Molybdenum', + 'Technetium', + 'Ruthenium', + 'Rhodium', + 'Palladium', + 'Silver', + 'Cadmium', + 'Indium', + 'Tin', + 'Antimony', + 'Tellurium', + 'Iodine', + 'Xenon', + 'Cesium', + 'Barium', + 'Lanthanum', + 'Cerium', + 'Praseodymium', + 'Neodymium', + 'Promethium', + 'Samarium', + 'Europium', + 'Gadolinium', + 'Terbium', + 'Dysprosium', + 'Holmium', + 'Erbium', + 'Thulium', + 'Ytterbium', + 'Lutetium', + 'Hafnium', + 'Tantalum', + 'Tungsten', + 'Rhenium', + 'Osmium', + 'Iridium', + 'Platinum', + 'Gold', + 'Mercury', + 'Thallium', + 'Lead', + 'Bismuth', + 'Polonium', + 'Astatine', + 'Radon', + 'Francium', + 'Radium', + 'Actinium', + 'Thorium', + 'Protactinium', + 'Uranium', + 'Neptunium', + 'Plutonium', + 'Americium', + 'Curium', + 'Berkelium', + 'Californium', + 'Einsteinium', + 'Fermium', + 'Mendelevium', + 'Nobelium', + 'Lawrencium', + 'Rutherfordium', + 'Dubnium', + 'Seaborgium', + 'Bohrium', + 'Hassium', + 'Meitnerium', + 'Darmstadtium', + 'Roentgenium', + 'Copernicium', + 'Nihonium', + 'Flerovium', + 'Moscovium', + 'Livermorium', + 'Tennessine', + 'Oganesson', +]) + +atomic_mass = np.array([ + 0.000, + 1.008, + 4.003, + 6.941, + 9.012, + 10.81, + 12.01, + 14.01, + 16.00, + 19.00, + 20.18, + 22.99, + 24.31, + 26.98, + 28.09, + 30.97, + 32.07, + 35.45, + 39.95, + 39.10, + 40.08, + 44.96, + 47.87, + 50.94, + 52.00, + 54.94, + 55.85, + 58.93, + 58.69, + 63.55, + 65.38, + 69.72, + 72.64, + 74.92, + 78.97, + 79.90, + 83.80, + 85.47, + 87.62, + 88.91, + 91.22, + 92.91, + 95.95, + 98.91, + 101.07, + 102.91, + 106.42, + 107.87, + 112.41, + 114.82, + 118.71, + 121.76, + 127.60, + 126.90, + 131.29, + 132.91, + 137.33, + 138.91, + 140.12, + 140.91, + 144.24, + 144.90, + 150.36, + 151.96, + 157.25, + 158.93, + 162.50, + 164.93, + 167.26, + 168.93, + 173.05, + 174.97, + 178.49, + 180.95, + 183.84, + 186.21, + 190.23, + 192.22, + 195.08, + 196.97, + 200.59, + 204.38, + 207.20, + 208.98, + 208.98, + 209.99, + 222.02, + 223.02, + 226.02, + 227.03, + 232.04, + 231.04, + 238.03, + 237.05, + 239.06, + 243.06, + 247.07, + 247.07, + 251.08, + 252.08, + 257.06, + 258.10, + 259.10, + 262.11, + 267.12, + 268.13, + 269.13, + 274.14, + 277.15, + 278.00, + 281.00, + 282.00, + 285.00, + 284.00, + 289.00, + 288.00, + 292.00, + 294.00, + 295.00, +]) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/export/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e82047eb4f53dd65d049c9c18279ff46d4f324d --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/export/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Export coordinate and trajectory files +""" + +from .h5md import H5MD +from .xyz import export_xyz + +__all__ = ['H5MD', 'export_xyz'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/export/h5md.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/export/h5md.py new file mode 100644 index 0000000000000000000000000000000000000000..ed5e733bc83177f1838301abc37111d4ee348211 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/export/h5md.py @@ -0,0 +1,462 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Export H5MD file. +""" + +import os +import numpy as np +from numpy import ndarray +import h5py +from h5py import Group +from mindspore.train._utils import _make_directory + +from ...system import Molecule +from ...function.units import Units, global_units + +_cur_dir = os.getcwd() + + +class H5MD: + r"""write HDF5 molecular data (H5MD) hdf5_file + + Reference: + + de Buyl, P.; Colberg, P. H., Höfling, F., + H5MD: A structured, efficient, and portable file format for molecular data [J]. + Computer Physics Communications, 2014, 185(6): 1546-1553. + + Args: + + system (Molecule): Simulation system + + filename (str): Name of output H5MD hdf5_file. + + directory (str): Directory of the output hdf5_file. Default: None + + write_velocity (bool): Whether to write the velocity of the system to the H5MD file. + Default: False + + write_force (bool): Whether to write the forece of the system to the H5MD file. + Default: False + + length_unit (str): Length unit for coordinates. + If given "None", it will be equal to the length unit of the system. + Default: None + + energy_unit (str): Energy unit. + If given "None", it will be equal to the global energy unit. + Default: None. + + compression (str): Compression strategy for HDF5. Default: 'gzip' + + compression_opts (int): Compression settings for HDF5. Default: 4 + + mode (str): I/O mode for HDF5. 
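+            The value is passed directly to ``h5py.File``, so any mode accepted by h5py
+            (for example 'a' to append to an existing file) can be used.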
Default: 'w' + + """ + + def __init__(self, + system: Molecule, + filename: str, + directory: str = None, + length_unit: str = None, + energy_unit: str = None, + compression: str = 'gzip', + compression_opts: int = 4, + mode: str = 'w', + ): + + if directory is not None: + self._directory = _make_directory(directory) + else: + self._directory = _cur_dir + self.filename = os.path.join(self._directory, filename) + + self.hdf5_file = h5py.File(self.filename, mode) + + self.h5md = self.hdf5_file.create_group('h5md') + self.h5md.attrs['version'] = [1, 1] + + self.h5md_author = self.h5md.create_group('author') + self.h5md_author.attrs['name'] = 'AIMM Group @ Shenzhen Bay Laboratory & Peking University' + self.h5md_author.attrs['email'] = 'yangyi@szbl.ac.cn' + + self.h5md_creator = self.h5md.create_group('creator') + self.h5md_creator.attrs['name'] = 'MindSPONGE' + self.h5md_creator.attrs['version'] = '0.5' + + if length_unit is None: + length_unit = system.length_unit + if energy_unit is None: + energy_unit = global_units.energy_unit + self.units = Units(length_unit, energy_unit) + + self.num_walker = system.num_walker + self.num_atoms = system.num_atoms + self.dimension = system.dimension + self.coordinate = system.coordinate.asnumpy() + self.crd_shape = (None, self.num_atoms, self.dimension) + + self.pbc_box = system.pbc_box + self.use_pbc = False + if self.pbc_box is not None: + self.pbc_box = system.pbc_box.asnumpy() + self.use_pbc = True + + self.compression = compression + self.compression_opts = compression_opts + + atomic_number = None + if system.atomic_number is not None: + atomic_number = system.atomic_number.asnumpy()[0] + + self.length_unit_scale = self.units.convert_length_from( + system.units) + self.force_unit_scale = self.units.convert_energy_from( + system.units) / self.length_unit_scale + + self.time_unit = 'ps' + + atom_name = None + if system.atom_name is not None: + atom_name = [s.encode('ascii', 'ignore') + for s in system.atom_name[0].tolist()] + + atom_type = None + if system.atom_type is not None: + atom_type = [s.encode('ascii', 'ignore') + for s in system.atom_type[0].tolist()] + + resname = None + if system.residue_name is not None: + resname = [s.encode('ascii', 'ignore') + for s in system.residue_name.tolist()] + + resid = None + if system.atom_resid is not None: + resid = system.atom_resid.asnumpy() + + bond_from = None + bond_to = None + if system.bond is not None: + bond_from = system.bond[0][..., 0].asnumpy() + 1 + bond_to = system.bond[0][..., 1].asnumpy() + 1 + + species = np.arange(self.num_atoms, dtype=np.int32) + + self.parameters = self.hdf5_file.create_group('parameters') + self.vmd_structure = self.create_vmd_structure(species, atomic_number, atom_name, atom_type, + resid, resname, bond_from, bond_to) + + self.shape = (self.num_atoms, self.dimension) + self.particles = self.hdf5_file.create_group('particles') + + if self.num_walker > 1: + self.position = [] + self.velocity = [] + self.force = [] + self.box = [] + self.trajectory = [] + for i in range(self.num_walker): + name = 'trajectory' + str(i) + trajectory = self.create_trajectory(species, name) + self.trajectory.append(trajectory) + + self.position.append(self.create_position( + self.trajectory[i], self.shape)) + self.box.append(self.create_box( + self.trajectory[i], self.shape)) + + else: + self.trajectory = self.create_trajectory(species, 'trajectory') + self.position = self.create_position(self.trajectory, self.shape) + self.box = self.create_box(self.trajectory, self.use_pbc) + + self.image = 
None + self.edges = None + self.velocity = None + self.force = None + + self.observables = self.hdf5_file.create_group('observables') + self.obs_group = None + + def create_element(self, group: h5py.Group, name: str, shape: tuple, dtype: str, unit: str = None) -> h5py.Group: + """create element""" + element = group.create_group(name) + element.create_dataset('step', shape=(0,), dtype='int32', maxshape=(None,), + compression=self.compression, compression_opts=self.compression_opts) + element.create_dataset('time', shape=(0,), dtype='float32', maxshape=(None,), + compression=self.compression, compression_opts=self.compression_opts) + element.create_dataset('value', shape=(0,)+shape, dtype=dtype, maxshape=(None,)+shape, + compression=self.compression, compression_opts=self.compression_opts) + element['time'].attrs['unit'] = self.time_unit + if unit is not None: + element['value'].attrs['unit'] = unit.encode('ascii', 'ignore') + return element + + def create_vmd_structure(self, + species: ndarray, + atomic_number: ndarray = None, + atom_name: ndarray = None, + atom_type: ndarray = None, + resid: ndarray = None, + resname: ndarray = None, + bond_from: ndarray = None, + bond_to: ndarray = None, + ): + """create the group 'vmd_structure'""" + + vmd_structure = self.parameters.create_group('vmd_structure') + vmd_structure.create_dataset( + 'indexOfSpecies', dtype='int32', data=species, + compression=self.compression, compression_opts=self.compression_opts) + + if atomic_number is not None: + vmd_structure.create_dataset('atomicnumber', dtype='int32', data=atomic_number, + compression=self.compression, compression_opts=self.compression_opts) + if atom_name is not None: + vmd_structure.create_dataset('name', data=atom_name, + compression=self.compression, compression_opts=self.compression_opts) + if atom_type is not None: + vmd_structure.create_dataset('type', data=atom_type, + compression=self.compression, compression_opts=self.compression_opts) + if resid is not None: + vmd_structure.create_dataset('resid', dtype='int32', data=resid, + compression=self.compression, compression_opts=self.compression_opts) + if resname is not None: + vmd_structure.create_dataset('resname', data=resname, + compression=self.compression, compression_opts=self.compression_opts) + if bond_from is not None: + vmd_structure.create_dataset('bond_from', dtype='int32', data=bond_from, + compression=self.compression, compression_opts=self.compression_opts) + vmd_structure.create_dataset('bond_to', dtype='int32', data=bond_to, + compression=self.compression, compression_opts=self.compression_opts) + + return vmd_structure + + def create_trajectory(self, species: ndarray, name: str = 'trajectory') -> h5py.Group: + """create the group 'trajectory'""" + trajectory = self.particles.create_group(name) + trajectory.create_dataset('species', dtype='int32', data=species, + compression=self.compression, compression_opts=self.compression_opts) + return trajectory + + def create_position(self, trajectory: h5py.Group, shape: tuple) -> h5py.Group: + """create the group 'position'""" + return self.create_element(trajectory, 'position', shape, 'float32', self.units.length_unit_name) + + def create_box(self, trajectory: h5py.Group, use_pbc: ndarray = None) -> h5py.Group: + """create the group 'box'""" + box = trajectory.create_group('box') + box.attrs['dimension'] = self.dimension + if use_pbc is None: + box.attrs['boundary'] = ['none'] * self.dimension + else: + box.attrs['boundary'] = ['periodic'] * self.dimension + return box + + def 
create_edges(self, box: h5py.Group, pbc_box: ndarray = None):
+        """create edges"""
+        if pbc_box is None:
+            edges = self.create_element(
+                box, 'edges', (self.dimension,), 'float32', self.units.length_unit_name)
+        else:
+            pbc_box *= self.length_unit_scale
+            edges = box.create_dataset('edges', data=pbc_box, dtype='float32',
+                                       compression=self.compression, compression_opts=self.compression_opts)
+            edges.attrs['unit'] = self.units.length_unit_name.encode(
+                'ascii', 'ignore')
+        return edges
+
+    def create_image(self, trajectory: h5py.Group, shape: tuple) -> h5py.Group:
+        """create the group 'image'"""
+        return self.create_element(trajectory, 'image', shape, 'int8')
+
+    def create_velocity(self, trajectory: h5py.Group, shape: tuple) -> h5py.Group:
+        """create the group 'velocity'"""
+        return self.create_element(trajectory, 'velocity', shape, 'float32', self.units.velocity_unit_name)
+
+    def create_force(self, trajectory: h5py.Group, shape: tuple) -> h5py.Group:
+        """create the group 'force'"""
+        return self.create_element(trajectory, 'force', shape, 'float32', self.units.force_unit_name)
+
+    def create_obs_group(self, name: str = 'trajectory') -> h5py.Group:
+        obs_group = self.observables.create_group(name)
+        obs_group.attrs['dimension'] = self.dimension
+        obs_group.create_dataset('particle_number', dtype='int32', data=[self.num_atoms],
+                                 compression=self.compression, compression_opts=self.compression_opts)
+        return obs_group
+
+    def set_box(self, constant_volume: bool = True):
+        """set PBC box information"""
+        if self.pbc_box is not None:
+            if self.num_walker > 1:
+                self.edges = []
+                for i in range(self.num_walker):
+                    if constant_volume:
+                        self.edges.append(self.create_edges(
+                            self.box, self.pbc_box[i]))
+                    else:
+                        self.edges.append(self.create_edges(self.box))
+            else:
+                if constant_volume:
+                    self.edges = self.create_edges(self.box, self.pbc_box[0])
+                else:
+                    self.edges = self.create_edges(self.box)
+        return self
+
+    def set_image(self):
+        """set group 'image'"""
+        if self.num_walker > 1:
+            self.image = []
+            for i in range(self.num_walker):
+                self.image.append(self.create_image(
+                    self.trajectory[i], self.shape))
+        else:
+            self.image = self.create_image(self.trajectory, self.shape)
+        return self
+
+    def set_velocity(self):
+        """set group 'velocity'"""
+        if self.num_walker > 1:
+            self.velocity = []
+            for i in range(self.num_walker):
+                self.velocity.append(self.create_velocity(
+                    self.trajectory[i], self.shape))
+        else:
+            self.velocity = self.create_velocity(self.trajectory, self.shape)
+        return self
+
+    def set_force(self):
+        """set group 'force'"""
+        if self.num_walker > 1:
+            self.force = []
+            for i in range(self.num_walker):
+                self.force.append(self.create_force(
+                    self.trajectory[i], self.shape))
+        else:
+            self.force = self.create_force(self.trajectory, self.shape)
+        return self
+
+    def set_observables(self, names: list, shapes: list, dtypes: list, units: list):
+        """set observables"""
+        if self.num_walker > 1:
+            self.obs_group = []
+            for i in range(self.num_walker):
+                obs_group = self.create_obs_group('trajectory' + str(i))
+                for name, shape, dtype, unit in zip(names, shapes, dtypes, units):
+                    self.create_element(obs_group, name, shape, dtype, unit)
+                self.obs_group.append(obs_group)
+        else:
+            self.obs_group = self.create_obs_group('trajectory')
+            for name, shape, dtype, unit in zip(names, shapes, dtypes, units):
+                self.create_element(self.obs_group, name, shape, dtype, unit)
+        return self
+
+    def write_element(self, group: Group, step: int, time: float, value: ndarray):
+        """write the element to the H5MD file"""
+        ds_step = group['step']
+        ds_step.resize(ds_step.shape[0]+1, axis=0)
+        ds_step[-1] = step
+
+        ds_time = group['time']
+        ds_time.resize(ds_time.shape[0]+1, axis=0)
+        ds_time[-1] = time
+
+        ds_value = group['value']
+        ds_value.resize(ds_value.shape[0]+1, axis=0)
+        ds_value[-1] = value
+        return self
+
+    def write_position(self, step: int, time: float, position: ndarray):
+        """write position"""
+        position *= self.length_unit_scale
+        if self.num_walker == 1:
+            self.write_element(self.position, step, time, position[0])
+        else:
+            for i in range(self.num_walker):
+                self.write_element(
+                    self.position[i], step, time, position[i])
+        return self
+
+    def write_box(self, step: int, time: float, box: ndarray):
+        """write box"""
+        box *= self.length_unit_scale
+        if self.num_walker == 1:
+            self.write_element(self.edges, step, time, box[0])
+        else:
+            for i in range(self.num_walker):
+                self.write_element(self.edges[i], step, time, box[i])
+        return self
+
+    def write_image(self, step: int, time: float, image: ndarray):
+        """write image"""
+        if self.num_walker == 1:
+            self.write_element(self.image, step, time,
+                               image[0].astype(np.int8))
+        else:
+            for i in range(self.num_walker):
+                self.write_element(
+                    self.image[i], step, time, image[i].astype(np.int8))
+        return self
+
+    def write_velocity(self, step: int, time: float, velocity: ndarray):
+        """write velocity"""
+        velocity *= self.length_unit_scale
+        if self.num_walker == 1:
+            self.write_element(self.velocity, step, time, velocity[0])
+        else:
+            for i in range(self.num_walker):
+                self.write_element(
+                    self.velocity[i], step, time, velocity[i])
+        return self
+
+    def write_force(self, step: int, time: float, force: ndarray):
+        """write force"""
+        force *= self.force_unit_scale
+        if self.num_walker == 1:
+            self.write_element(self.force, step, time, force[0])
+        else:
+            for i in range(self.num_walker):
+                self.write_element(
+                    self.force[i], step, time, force[i])
+        return self
+
+    def write_observables(self, names: list, step: int, time: float, values: list, index: int = None):
+        """write observables"""
+        if index is None and self.num_walker > 1:
+            raise ValueError(
+                'The "index" must be given when using multiple walkers')
+        if self.num_walker == 1:
+            for name, value in zip(names, values):
+                self.write_element(self.obs_group[name], step, time, value)
+        else:
+            for name, value in zip(names, values):
+                self.write_element(
+                    self.obs_group[index][name], step, time, value)
+        return self
+
+    def close(self):
+        """close the HDF5 file"""
+        return self.hdf5_file.close()
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/export/xyz.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/export/xyz.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddac82da7a89d2b7823287952f95d9a6ddb4e26e
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/export/xyz.py
@@ -0,0 +1,48 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Export xyz files.
+"""
+
+import os
+from numpy import ndarray
+from ...function.functions import get_ndarray
+
+
+def export_xyz(filename: str, atom: ndarray, coordinate: ndarray, mol_name: str = '', accuracy: str = '{:>12.6f}'):
+    """export xyz file"""
+    atom = get_ndarray(atom)
+    coordinate = get_ndarray(coordinate)
+    natom = atom.shape[-1]
+    if coordinate.shape[-2] != natom:
+        raise ValueError('The penultimate dimension of coordinate (' +
+                         str(coordinate.shape[-2]) + ') must be equal to the number of atoms (' +
+                         str(natom) + ')!')
+    with open(filename, mode='w+') as ofile:
+        ofile.write(str(natom) + os.linesep)
+        ofile.write(' ' + mol_name + os.linesep)
+        for a, r in zip(atom, coordinate):
+            ofile.write('{:>3d}'.format(a))
+            for ri in r:
+                ofile.write(accuracy.format(ri))
+            ofile.write(os.linesep)
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bae871a1d5ca92dbe9f625855635c78a7a28ac37
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +""" +Force field parameters +""" + +from .forcefield import get_forcefield + +__all__ = ['get_forcefield'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/amber.ff14sb.yaml b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/amber.ff14sb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..161878b237f72201ca3f7e728fd09c522c5568ba --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/amber.ff14sb.yaml @@ -0,0 +1,2137 @@ +template: + base: protein0.yaml + ALA: + atom_type: [N, H, CX, H1, CT, HC, HC, HC, C, O] + atom_charge: [-0.4156928, 0.2718953, 0.0336994, 0.0822986, -0.1824969, 0.060299, + 0.060299, 0.060299, 0.5972897, -0.5678902] + ARG: + atom_type: [N, H, CX, H1, C8, HC, HC, C8, HC, HC, C8, H1, H1, N2, H, CA, N2, + H, H, N2, H, H, C, O] + atom_charge: [-0.347894, 0.2746953, -0.2636955, 0.1559973, -0.0007, 0.0326994, + 0.0326994, 0.0389993, 0.0284995, 0.0284995, 0.0485992, 0.0686988, 0.0686988, + -0.5294908, 0.345594, 0.807586, -0.8626851, 0.4477923, 0.4477923, -0.8626851, + 0.4477923, 0.4477923, 0.7340873, -0.5893898] + ASN: + atom_type: [N, H, CX, H1, 2C, HC, HC, C, O, N, H, H, C, O] + atom_charge: [-0.4156928, 0.2718953, 0.0142998, 0.1047982, -0.2040964, 0.0796986, + 0.0796986, 0.7129877, -0.5930897, -0.9190841, 0.4195927, 0.4195927, 0.5972897, + -0.5678902] + ASP: + atom_type: [N, H, CX, H1, 2C, HC, HC, CO, O2, O2, C, O] + atom_charge: [-0.516291, 0.2935949, 0.0380994, 0.0879985, -0.0302995, -0.0121998, + -0.0121998, 0.7993862, -0.8013861, -0.8013861, 0.5365907, -0.5818899] + CYS: + atom_type: [N, H, CX, H1, 2C, H1, H1, SH, HS, C, O] + atom_charge: [-0.4156928, 0.2718953, 0.0212996, 0.1123981, -0.1230979, 0.1111981, + 0.1111981, -0.3118946, 0.1932967, 0.5972897, -0.5678902] + GLN: + atom_type: [N, H, CX, H1, 2C, HC, HC, 2C, HC, HC, C, O, N, H, H, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0030999, 0.0849986, -0.0035999, 0.0170997, + 0.0170997, -0.0644989, 0.0351994, 0.0351994, 0.695088, -0.6085895, -0.9406837, + 0.4250927, 0.4250927, 0.5972897, -0.5678902] + GLU: + atom_type: [N, H, CX, H1, 2C, HC, HC, 2C, HC, HC, CO, O2, O2, C, O] + atom_charge: [-0.516291, 0.2935949, 0.0396993, 0.1104981, 0.055999, -0.0172997, + -0.0172997, 0.0135997, -0.0424993, -0.0424993, 0.805386, -0.8187858, -0.8187858, + 0.5365907, -0.5818899] + GLY: + atom_type: [N, H, CX, H1, H1, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0251996, 0.0697988, 0.0697988, 0.5972897, + -0.5678902] + HID: + atom_type: [N, H, CX, H1, CT, HC, HC, CC, NA, H, CR, H5, NB, CV, H4, C, O] + atom_charge: [-0.4156928, 0.2718953, 0.0187997, 0.0880985, -0.0461992, 0.0401993, + 0.0401993, -0.0265995, -0.3810934, 0.3648937, 0.2056964, 0.1391976, -0.5726901, + 0.1291978, 0.114698, 0.5972897, -0.5678902] + HIS: + atom_type: [N, H, CX, H1, CT, HC, HC, CC, NB, CR, H5, NA, H, CW, H4, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.058099, 0.1359977, -0.0073999, 0.0366993, + 0.0366993, 0.1867968, -0.5431906, 0.1634972, 0.1434975, -0.2794952, 0.3338942, + -0.2206962, 0.1861968, 0.5972897, -0.5678902] + ILE: + atom_type: [N, H, CX, H1, 3C, HC, CT, HC, HC, HC, 2C, HC, HC, CT, HC, HC, HC, + C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0596989, 0.0868985, 0.1302978, 0.0186997, + -0.3203945, 0.0881985, 0.0881985, 0.0881985, -0.0429993, 0.0235996, 0.0235996, + -0.0659989, 0.0185997, 0.0185997, 0.0185997, 0.5972897, -0.5678902] + LEU: + atom_type: [N, H, CX, H1, 
2C, HC, HC, 3C, HC, CT, HC, HC, HC, CT, HC, HC, HC, + C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0517991, 0.0921984, -0.1101981, 0.0456992, + 0.0456992, 0.3530939, -0.0360994, -0.4120929, 0.0999983, 0.0999983, 0.0999983, + -0.4120929, 0.0999983, 0.0999983, 0.0999983, 0.5972897, -0.5678902] + LYS: + atom_type: [N, H, CX, H1, C8, HC, HC, C8, HC, HC, C8, HC, HC, C8, HP, HP, N3, + H, H, H, C, O] + atom_charge: [-0.347894, 0.2746953, -0.2399958, 0.1425975, -0.0093999, 0.0361994, + 0.0361994, 0.0186997, 0.0102998, 0.0102998, -0.0478992, 0.0620989, 0.0620989, + -0.0142998, 0.113498, 0.113498, -0.3853933, 0.3399941, 0.3399941, 0.3399941, + 0.7340873, -0.5893898] + MET: + atom_type: [N, H, CX, H1, 2C, HC, HC, 2C, H1, H1, S, CT, H1, H1, H1, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0236996, 0.0879985, 0.0341994, 0.0240996, + 0.0240996, 0.0018, 0.0439992, 0.0439992, -0.2736953, -0.0535991, 0.0683988, + 0.0683988, 0.0683988, 0.5972897, -0.5678902] + PHE: + atom_type: [N, H, CX, H1, CT, HC, HC, CA, CA, HA, CA, HA, CA, HA, CA, HA, CA, + HA, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0024, 0.0977983, -0.0342994, 0.0294995, + 0.0294995, 0.0117998, -0.1255978, 0.1329977, -0.1703971, 0.1429975, -0.1071982, + 0.1296977, -0.1703971, 0.1429975, -0.1255978, 0.1329977, 0.5972897, -0.5678902] + PRO: + atom_type: [N, CT, H1, H1, CT, HC, HC, CT, HC, HC, CX, H1, C, O] + atom_charge: [-0.2547956, 0.0191997, 0.0390993, 0.0390993, 0.0188996, 0.0212996, + 0.0212996, -0.0069999, 0.0252996, 0.0252996, -0.0265995, 0.0640989, 0.5895898, + -0.57479] + SER: + atom_type: [N, H, CX, H1, 2C, H1, H1, OH, HO, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0248996, 0.0842985, 0.2116963, 0.0351994, + 0.0351994, -0.6545887, 0.4274926, 0.5972897, -0.5678902] + THR: + atom_type: [N, H, CX, H1, 3C, H1, CT, HC, HC, HC, OH, HO, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0388993, 0.1006983, 0.3653937, 0.0042999, + -0.2437958, 0.0641989, 0.0641989, 0.0641989, -0.6760883, 0.4101929, 0.5972897, + -0.5678902] + TRP: + atom_type: [N, H, CX, H1, CT, HC, HC, C*, CW, H4, NA, H, CN, CA, HA, CA, HA, + CA, HA, CA, HA, CB, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0274995, 0.112298, -0.0049999, 0.0338994, + 0.0338994, -0.1414975, -0.1637972, 0.2061964, -0.3417941, 0.3411941, 0.1379976, + -0.2600955, 0.1571973, -0.113398, 0.1416976, -0.1971966, 0.1446975, -0.2386959, + 0.1699971, 0.1242979, 0.5972897, -0.5678902] + TYR: + atom_type: [N, H, CX, H1, CT, HC, HC, CA, CA, HA, CA, HA, C, OH, HO, CA, HA, + CA, HA, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0014, 0.0875985, -0.0151997, 0.0294995, + 0.0294995, -0.0011, -0.1905967, 0.1698971, -0.2340959, 0.1655971, 0.3225944, + -0.5578903, 0.3991931, -0.2340959, 0.1655971, -0.1905967, 0.1698971, 0.5972897, + -0.5678902] + VAL: + atom_type: [N, H, CX, H1, 3C, HC, CT, HC, HC, HC, CT, HC, HC, HC, C, O] + atom_charge: [-0.4156928, 0.2718953, -0.0874985, 0.0968983, 0.2984949, -0.0296995, + -0.3191945, 0.0790986, 0.0790986, 0.0790986, -0.3191945, 0.0790986, 0.0790986, + 0.0790986, 0.5972897, -0.5678902] + NALA: + atom_type: [N3, H, H, H, CX, HP, CT, HC, HC, HC, C, O] + atom_charge: [0.1413975, 0.1996965, 0.1996965, 0.1996965, 0.0961983, 0.0888984, + -0.0596989, 0.0299995, 0.0299995, 0.0299995, 0.6162893, -0.5721901] + NARG: + atom_type: [N3, H, H, H, CX, HP, C8, HC, HC, C8, HC, HC, C8, H1, H1, N2, H, CA, + N2, H, H, N2, H, H, C, O] + atom_charge: [0.1304977, 0.2082964, 0.2082964, 0.2082964, -0.0222996, 0.1241979, + 0.0117998, 0.0225996, 0.0225996, 0.0235996, 0.0308995, 0.0308995, 
0.0934984, + 0.0526991, 0.0526991, -0.5649902, 0.3591938, 0.8280857, -0.8692849, 0.4493922, + 0.4493922, -0.8692849, 0.4493922, 0.4493922, 0.7213875, -0.6012896] + NASN: + atom_type: [N3, H, H, H, CX, HP, 2C, HC, HC, C, O, N, H, H, C, O] + atom_charge: [0.1800969, 0.1920967, 0.1920967, 0.1920967, 0.0367994, 0.1230979, + -0.0282995, 0.0514991, 0.0514991, 0.5832899, -0.5743901, -0.8633851, 0.4096929, + 0.4096929, 0.6162893, -0.5721901] + NASP: + atom_type: [N3, H, H, H, CX, HP, 2C, HC, HC, CO, O2, O2, C, O] + atom_charge: [0.0781987, 0.2199962, 0.2199962, 0.2199962, 0.0291995, 0.114098, + -0.0234996, -0.0168997, -0.0168997, 0.8193858, -0.808386, -0.808386, 0.5620903, + -0.5888898] + NCYS: + atom_type: [N3, H, H, H, CX, HP, 2C, H1, H1, SH, HS, C, O] + atom_charge: [0.1324977, 0.2022965, 0.2022965, 0.2022965, 0.0926984, 0.1410976, + -0.1194979, 0.1187979, 0.1187979, -0.3297943, 0.1974966, 0.6122894, -0.5712901] + NGLN: + atom_type: [N3, H, H, H, CX, HP, 2C, HC, HC, 2C, HC, HC, C, O, N, H, H, C, O] + atom_charge: [0.1492974, 0.1995965, 0.1995965, 0.1995965, 0.0535991, 0.1014982, + 0.0650989, 0.0049999, 0.0049999, -0.0902985, 0.0330994, 0.0330994, 0.7353872, + -0.6132894, -1.0030826, 0.4428924, 0.4428924, 0.6122894, -0.5712901] + NGLU: + atom_type: [N3, H, H, H, CX, HP, 2C, HC, HC, 2C, HC, HC, CO, O2, O2, C, O] + atom_charge: [0.0017, 0.2390959, 0.2390959, 0.2390959, 0.058799, 0.1201979, 0.0908984, + -0.0231996, -0.0231996, -0.0235996, -0.0314994, -0.0314994, 0.808686, -0.8188858, + -0.8188858, 0.5620903, -0.5888898] + NGLY: + atom_type: [N3, H, H, H, CX, HP, HP, C, O] + atom_charge: [0.2942949, 0.1641972, 0.1641972, 0.1641972, -0.0099998, 0.0894985, + 0.0894985, 0.6162893, -0.5721901] + NHID: + atom_type: [N3, H, H, H, CX, HP, CT, HC, HC, CC, NA, H, CR, H5, NB, CV, H4, C, + O] + atom_charge: [0.1541973, 0.1962966, 0.1962966, 0.1962966, 0.0963983, 0.0957983, + 0.0258996, 0.0208996, 0.0208996, -0.0398993, -0.3818934, 0.3631937, 0.2126963, + 0.1384976, -0.5710901, 0.1045982, 0.1298978, 0.6122894, -0.5712901] + NHIS: + atom_type: [N3, H, H, H, CX, HP, CT, HC, HC, CC, NB, CR, H5, NA, H, CW, H4, C, + O] + atom_charge: [0.1471975, 0.2015965, 0.2015965, 0.2015965, 0.0235996, 0.1379976, + 0.0488991, 0.0222996, 0.0222996, 0.173997, -0.5578903, 0.1803969, 0.1396976, + -0.2780952, 0.3323943, -0.2348959, 0.1962966, 0.6122894, -0.5712901] + NILE: + atom_type: [N3, H, H, H, CX, HP, 3C, HC, CT, HC, HC, HC, 2C, HC, HC, CT, HC, + HC, HC, C, O] + atom_charge: [0.0310995, 0.232896, 0.232896, 0.232896, 0.0256995, 0.1030982, + 0.1884968, 0.0212996, -0.3719936, 0.0946984, 0.0946984, 0.0946984, -0.0386993, + 0.0200996, 0.0200996, -0.0907984, 0.0225996, 0.0225996, 0.0225996, 0.6122894, + -0.5712901] + NLEU: + atom_type: [N3, H, H, H, CX, HP, 2C, HC, HC, 3C, HC, CT, HC, HC, HC, CT, HC, + HC, HC, C, O] + atom_charge: [0.1009982, 0.2147963, 0.2147963, 0.2147963, 0.0103998, 0.1052982, + -0.0243996, 0.0255996, 0.0255996, 0.3420941, -0.0379993, -0.4105929, 0.0979983, + 0.0979983, 0.0979983, -0.4103929, 0.0979983, 0.0979983, 0.0979983, 0.6122894, + -0.5712901] + NLYS: + atom_type: [N3, H, H, H, CX, HP, C8, HC, HC, C8, HC, HC, C8, HC, HC, C8, HP, + HP, N3, H, H, H, C, O] + atom_charge: [0.0965983, 0.2164963, 0.2164963, 0.2164963, -0.0014999, 0.1179979, + 0.0211996, 0.0282995, 0.0282995, -0.0047999, 0.0120998, 0.0120998, -0.060799, + 0.0632989, 0.0632989, -0.0180997, 0.117098, 0.117098, -0.3763935, 0.3381942, + 0.3381942, 0.3381942, 0.7213875, -0.6012896] + NMET: + atom_type: [N3, H, H, H, CX, HP, 2C, HC, HC, 2C, H1, H1, S, CT, 
H1, H1, H1, C, + O] + atom_charge: [0.1591972, 0.1983965, 0.1983965, 0.1983965, 0.0220996, 0.1115981, + 0.0864985, 0.0124998, 0.0124998, 0.0333994, 0.0291995, 0.0291995, -0.2773952, + -0.0340994, 0.0596989, 0.0596989, 0.0596989, 0.6122894, -0.5712901] + NPHE: + atom_type: [N3, H, H, H, CX, HP, CT, HC, HC, CA, CA, HA, CA, HA, CA, HA, CA, + HA, CA, HA, C, O] + atom_charge: [0.173697, 0.1920967, 0.1920967, 0.1920967, 0.0732988, 0.1040982, + 0.0329994, 0.0103998, 0.0103998, 0.0030999, -0.1391976, 0.1373976, -0.1601972, + 0.1432975, -0.1207979, 0.1328977, -0.1602972, 0.1432975, -0.1390976, 0.1373976, + 0.6122894, -0.5712901] + NPRO: + atom_type: [N3, H, H, CT, HP, HP, CT, HC, HC, CT, HC, HC, CX, HP, C, O] + atom_charge: [-0.2019965, 0.3119946, 0.3119946, -0.0119998, 0.0999983, 0.0999983, + -0.1209979, 0.0999983, 0.0999983, -0.114998, 0.0999983, 0.0999983, 0.0999983, + 0.0999983, 0.5259909, -0.4999913] + NSER: + atom_type: [N3, H, H, H, CX, HP, 2C, H1, H1, OH, HO, C, O] + atom_charge: [0.1848968, 0.1897967, 0.1897967, 0.1897967, 0.056699, 0.0781987, + 0.2595955, 0.0272995, 0.0272995, -0.6713884, 0.4238927, 0.6162893, -0.5721901] + NTHR: + atom_type: [N3, H, H, H, CX, HP, 3C, H1, CT, HC, HC, HC, OH, HO, C, O] + atom_charge: [0.1811969, 0.1933967, 0.1933967, 0.1933967, 0.0034, 0.1086981, + 0.4513922, -0.0322994, -0.2553956, 0.0626989, 0.0626989, 0.0626989, -0.6763883, + 0.4069929, 0.6162893, -0.5721901] + NTRP: + atom_type: [N3, H, H, H, CX, HP, CT, HC, HC, C*, CW, H4, NA, H, CN, CA, HA, CA, + HA, CA, HA, CA, HA, CB, C, O] + atom_charge: [0.1912967, 0.1887967, 0.1887967, 0.1887967, 0.0420993, 0.116198, + 0.0542991, 0.0221996, 0.0221996, -0.1653971, -0.1787969, 0.2194962, -0.344394, + 0.3411941, 0.1574973, -0.2709953, 0.1588972, -0.1079981, 0.1410976, -0.2033965, + 0.1457975, -0.2264961, 0.1645972, 0.113198, 0.6122894, -0.5712901] + NTYR: + atom_type: [N3, H, H, H, CX, HP, CT, HC, HC, CA, CA, HA, CA, HA, C, OH, HO, CA, + HA, CA, HA, C, O] + atom_charge: [0.1939966, 0.1872968, 0.1872968, 0.1872968, 0.056999, 0.0982983, + 0.0658989, 0.0101998, 0.0101998, -0.0204996, -0.2001965, 0.171997, -0.2238961, + 0.1649972, 0.3138946, -0.5577903, 0.4000931, -0.2238961, 0.1649972, -0.2001965, + 0.171997, 0.6122894, -0.5712901] + NVAL: + atom_type: [N3, H, H, H, CX, HP, 3C, HC, CT, HC, HC, HC, CT, HC, HC, HC, C, O] + atom_charge: [0.057699, 0.2271961, 0.2271961, 0.2271961, -0.0053999, 0.1092981, + 0.3195945, -0.0220996, -0.3128946, 0.0734987, 0.0734987, 0.0734987, -0.3128946, + 0.0734987, 0.0734987, 0.0734987, 0.6162893, -0.5721901] + CALA: + atom_type: [N, H, CX, H1, CT, HC, HC, HC, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.174697, 0.1066981, -0.2092964, 0.0763987, + 0.0763987, 0.0763987, 0.7730866, -0.8054861, -0.8054861] + CARG: + atom_type: [N, H, CX, H1, C8, HC, HC, C8, HC, HC, C8, H1, H1, N2, H, CA, N2, + H, H, N2, H, H, C, O2, O2] + atom_charge: [-0.348094, 0.2763952, -0.3067947, 0.1446975, -0.0373994, 0.0370993, + 0.0370993, 0.0743987, 0.0184997, 0.0184997, 0.1113981, 0.0467992, 0.0467992, + -0.5563904, 0.347894, 0.8367855, -0.8736849, 0.4492922, 0.4492922, -0.8736849, + 0.4492922, 0.4492922, 0.8556852, -0.8265857, -0.8265857] + CASN: + atom_type: [N, H, CX, H1, 2C, HC, HC, C, O, N, H, H, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2079964, 0.1357976, -0.229896, 0.1022982, + 0.1022982, 0.7152876, -0.6009896, -0.9083843, 0.4149928, 0.4149928, 0.8049861, + -0.8146859, -0.8146859] + CASP: + atom_type: [N, H, CX, H1, 2C, HC, HC, CO, O2, O2, C, O2, O2] + atom_charge: [-0.519191, 0.3054947, 
-0.1816969, 0.1045982, -0.0676988, -0.0211996, + -0.0211996, 0.8850847, -0.8161859, -0.8161859, 0.7255874, -0.7886863, -0.7886863] + CCYS: + atom_type: [N, H, CX, H1, 2C, H1, H1, SH, HS, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.1634972, 0.1395976, -0.1995965, 0.1436975, + 0.1436975, -0.3101946, 0.2067964, 0.749687, -0.7980862, -0.7980862] + CGLN: + atom_type: [N, H, CX, H1, 2C, HC, HC, 2C, HC, HC, C, O, N, H, H, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2247961, 0.1231978, -0.0663989, 0.0451992, + 0.0451992, -0.0209996, 0.0202997, 0.0202997, 0.7092877, -0.6097895, -0.9573834, + 0.4303926, 0.4303926, 0.7774865, -0.8041861, -0.8041861] + CGLU: + atom_type: [N, H, CX, H1, 2C, HC, HC, 2C, HC, HC, CO, O2, O2, C, O2, O2] + atom_charge: [-0.519191, 0.3054947, -0.2058965, 0.1398976, 0.0070999, -0.0077999, + -0.0077999, 0.0674988, -0.054799, -0.054799, 0.8182858, -0.8219858, -0.8219858, + 0.7419872, -0.7929863, -0.7929863] + CGLY: + atom_type: [N, H, CX, H1, H1, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2492957, 0.1055982, 0.1055982, 0.7230875, + -0.7854864, -0.7854864] + CHID: + atom_type: [N, H, CX, H1, CT, HC, HC, CC, NA, H, CR, H5, NB, CV, H4, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.173897, 0.1099981, -0.1045982, 0.056499, + 0.056499, 0.0292995, -0.3891933, 0.3754935, 0.1924967, 0.1417975, -0.5628903, + 0.1000983, 0.1240978, 0.7614868, -0.8015861, -0.8015861] + CHIS: + atom_type: [N, H, CX, H1, CT, HC, HC, CC, NB, CR, H5, NA, H, CW, H4, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2698953, 0.1649972, -0.1067982, 0.0619989, + 0.0619989, 0.2723953, -0.5516905, 0.1557973, 0.1447975, -0.2669954, 0.3318942, + -0.2587955, 0.1956966, 0.7915863, -0.806486, -0.806486] + CILE: + atom_type: [N, H, CX, H1, 3C, HC, CT, HC, HC, HC, 2C, HC, HC, CT, HC, HC, HC, + C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.3099946, 0.1374976, 0.0362993, 0.0765987, + -0.349794, 0.1020982, 0.1020982, 0.1020982, -0.0322994, 0.0320995, 0.0320995, + -0.0698988, 0.0195997, 0.0195997, 0.0195997, 0.8342856, -0.8189858, -0.8189858] + CLEU: + atom_type: [N, H, CX, H1, 2C, HC, HC, 3C, HC, CT, HC, HC, HC, CT, HC, HC, HC, + C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2846951, 0.1345977, -0.2468957, 0.0973983, + 0.0973983, 0.3705936, -0.0373994, -0.4162928, 0.1037982, 0.1037982, 0.1037982, + -0.4162928, 0.1037982, 0.1037982, 0.1037982, 0.8325856, -0.8198858, -0.8198858] + CLYS: + atom_type: [N, H, CX, H1, C8, HC, HC, C8, HC, HC, C8, HC, HC, C8, HP, HP, N3, + H, H, H, C, O2, O2] + atom_charge: [-0.348094, 0.2763952, -0.290295, 0.1437975, -0.0537991, 0.0481992, + 0.0481992, 0.0226996, 0.0133998, 0.0133998, -0.0391993, 0.061099, 0.061099, + -0.0175997, 0.1120981, 0.1120981, -0.3740935, 0.3373942, 0.3373942, 0.3373942, + 0.8487853, -0.8251857, -0.8251857] + CMET: + atom_type: [N, H, CX, H1, 2C, HC, HC, 2C, H1, H1, S, CT, H1, H1, H1, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2596955, 0.1276978, -0.0235996, 0.0479991, + 0.0479991, 0.0491991, 0.0316995, 0.0316995, -0.2691953, -0.0375993, 0.0624989, + 0.0624989, 0.0624989, 0.8012861, -0.810486, -0.810486] + CPHE: + atom_type: [N, H, CX, H1, CT, HC, HC, CA, CA, HA, CA, HA, CA, HA, CA, HA, CA, + HA, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.1824969, 0.1097981, -0.0958984, 0.0442992, + 0.0442992, 0.055199, -0.1299977, 0.1407976, -0.1846968, 0.1460975, -0.0943984, + 0.1279978, -0.1846968, 0.1460975, -0.1299977, 0.1407976, 0.7659868, -0.8025861, + -0.8025861] + CPRO: + atom_type: [N, CT, H1, H1, CT, HC, HC, 
CT, HC, HC, CX, H1, C, O2, O2] + atom_charge: [-0.2801951, 0.0433993, 0.0330994, 0.0330994, 0.0465992, 0.0171997, + 0.0171997, -0.0542991, 0.0380994, 0.0380994, -0.1335977, 0.0775986, 0.6630885, + -0.7696867, -0.7696867] + CSER: + atom_type: [N, H, CX, H1, 2C, H1, H1, OH, HO, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2721953, 0.1303977, 0.112298, 0.0812986, + 0.0812986, -0.6513887, 0.4473923, 0.811286, -0.8131859, -0.8131859] + CTHR: + atom_type: [N, H, CX, H1, 3C, H1, CT, HC, HC, HC, OH, HO, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2419958, 0.1206979, 0.3024948, 0.0077999, + -0.1852968, 0.058599, 0.058599, 0.058599, -0.6495888, 0.4118928, 0.7809865, + -0.8043861, -0.8043861] + CTRP: + atom_type: [N, H, CX, H1, CT, HC, HC, C*, CW, H4, NA, H, CN, CA, HA, CA, HA, + CA, HA, CA, HA, CB, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2083964, 0.1271978, -0.0741987, 0.0496991, + 0.0496991, -0.0795986, -0.1807969, 0.2042965, -0.3315943, 0.3412941, 0.1221979, + -0.2593955, 0.1566973, -0.1019983, 0.1400976, -0.228696, 0.1506974, -0.1836968, + 0.1490974, 0.1077981, 0.7657867, -0.8010862, -0.8010862] + CTYR: + atom_type: [N, H, CX, H1, CT, HC, HC, CA, CA, HA, CA, HA, C, OH, HO, CA, HA, + CA, HA, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.2014965, 0.1091981, -0.0751987, 0.0489992, + 0.0489992, 0.0242996, -0.1921967, 0.1779969, -0.2457957, 0.1672971, 0.3394941, + -0.5642902, 0.4016931, -0.2457957, 0.1672971, -0.1921967, 0.1779969, 0.7816865, + -0.806986, -0.806986] + CVAL: + atom_type: [N, H, CX, H1, 3C, HC, CT, HC, HC, HC, CT, HC, HC, HC, C, O2, O2] + atom_charge: [-0.3820934, 0.2680954, -0.3437941, 0.1437975, 0.1939966, 0.0307995, + -0.3063947, 0.0835985, 0.0835985, 0.0835985, -0.3063947, 0.0835985, 0.0835985, + 0.0835985, 0.8349855, -0.8172859, -0.8172859] + ACE: + atom_type: [HC, CT, HC, HC, C, O] + atom_charge: [0.112298, -0.3661936, 0.112298, 0.112298, 0.5971897, -0.5678902] + NME: + atom_type: [N, H, CT, H1, H1, H1] + atom_charge: [-0.4156928, 0.2718953, -0.1489974, 0.0975983, 0.0975983, 0.0975983] + +parameters: + bond_energy: + length_unit: nm + energy_unit: kj/mol + parameter_names: + pattern: [bond_length, force_constant] + parameters: + OW-HW: [0.09572, 462750.4] + HW-HW: [0.15136, 462750.4] + C-C: [0.1525, 259408.0] + C-CA: [0.1409, 392459.2] + C-CB: [0.1419, 374049.6] + C-CM: [0.1444, 343088.0] + C-CS: [0.1444, 343088.0] + C-CT: [0.1522, 265265.6] + C-CX: [0.1522, 265265.6] + C-N: [0.1335, 410032.0] + C-N*: [0.1383, 354803.2] + C-NA: [0.1388, 349782.4] + C-NC: [0.1358, 382417.6] + C-O: [0.1229, 476976.0] + C-O2: [0.125, 548940.8] + C-OH: [0.1364, 376560.0] + C-OS: [0.1323, 376560.0] + C-H4: [0.108, 307105.6] + C-H5: [0.108, 307105.6] + CA-CA: [0.14, 392459.2] + CA-CB: [0.1404, 392459.2] + CA-CM: [0.1433, 357313.6] + CA-CS: [0.1433, 357313.6] + CA-CN: [0.14, 392459.2] + CA-CT: [0.151, 265265.6] + CA-HA: [0.108, 307105.6] + CA-H4: [0.108, 307105.6] + CA-N2: [0.134, 402500.8] + CA-NA: [0.1381, 357313.6] + CA-NC: [0.1339, 404174.4] + CA-OH: [0.1364, 376560.0] + CB-CB: [0.137, 435136.0] + CB-N*: [0.1374, 364844.8] + CB-NB: [0.1391, 346435.2] + CB-NC: [0.1354, 385764.8] + CD-HA: [0.108, 307105.6] + CD-CD: [0.14, 392459.2] + CD-CM: [0.135, 459403.2] + CD-CS: [0.135, 459403.2] + CD-CT: [0.151, 265265.6] + CK-H5: [0.108, 307105.6] + CK-N*: [0.1371, 368192.0] + CK-NB: [0.1304, 442667.2] + CP-H5: [0.108, 307105.6] + CP-N*: [0.1371, 368192.0] + CP-NB: [0.1304, 442667.2] + CM-CM: [0.135, 459403.2] + CM-CT: [0.151, 265265.6] + CM-HA: [0.108, 307105.6] + CM-H4: 
[0.108, 307105.6] + CM-H5: [0.108, 307105.6] + CM-N*: [0.1365, 374886.4] + CM-OS: [0.124, 401664.0] + CS-CS: [0.135, 459403.2] + CS-CT: [0.151, 265265.6] + CS-HA: [0.108, 307105.6] + CS-H4: [0.108, 307105.6] + CS-H5: [0.108, 307105.6] + CS-N*: [0.1365, 374886.4] + CS-OS: [0.124, 401664.0] + CQ-H5: [0.108, 307105.6] + CQ-NC: [0.1324, 420073.6] + CT-CT: [0.1526, 259408.0] + CX-CT: [0.1526, 259408.0] + CT-HC: [0.109, 284512.0] + CT-H1: [0.109, 284512.0] + CX-H1: [0.109, 284512.0] + CT-H2: [0.109, 284512.0] + CT-H3: [0.109, 284512.0] + CT-HP: [0.109, 284512.0] + CX-HP: [0.109, 284512.0] + CT-N*: [0.1475, 282001.6] + CT-N2: [0.1463, 282001.6] + CT-OH: [0.141, 267776.0] + CT-OS: [0.141, 267776.0] + C*-HC: [0.108, 307105.6] + C*-CB: [0.1459, 324678.4] + C*-CT: [0.1495, 265265.6] + C*-CW: [0.1352, 456892.8] + CB-CN: [0.1419, 374049.6] + CC-CT: [0.1504, 265265.6] + CC-CV: [0.1375, 428441.6] + CC-CW: [0.1371, 433462.4] + CC-NA: [0.1385, 353129.6] + CC-NB: [0.1394, 343088.0] + CN-NA: [0.138, 358150.4] + CR-H5: [0.108, 307105.6] + CR-NA: [0.1343, 399153.6] + CR-NB: [0.1335, 408358.4] + CT-N: [0.1449, 282001.6] + CX-N: [0.1449, 282001.6] + CT-N3: [0.1471, 307105.6] + CX-N3: [0.1471, 307105.6] + CT-NT: [0.1471, 307105.6] + CT-S: [0.181, 189953.6] + CT-SH: [0.181, 198321.6] + CT-CY: [0.1458, 334720.0] + CT-CZ: [0.1459, 334720.0] + CV-H4: [0.108, 307105.6] + CV-NB: [0.1394, 343088.0] + CW-H4: [0.108, 307105.6] + CW-NA: [0.1381, 357313.6] + CY-NY: [0.115, 502080.0] + CZ-CZ: [0.1206, 502080.0] + CZ-HZ: [0.1056, 334720.0] + OP-P: [0.148, 439320.0] + O2-P: [0.148, 439320.0] + OH-P: [0.161, 192464.0] + OS-P: [0.161, 192464.0] + NA-P: [0.184, 209200.0] + H-N2: [0.101, 363171.2] + H-N*: [0.101, 363171.2] + H-NA: [0.101, 363171.2] + H-N: [0.101, 363171.2] + H-N3: [0.101, 363171.2] + H-NT: [0.101, 363171.2] + HO-OH: [0.096, 462750.4] + HO-OS: [0.096, 462750.4] + HS-SH: [0.1336, 229283.2] + S-S: [0.2038, 138908.8] + F-CT: [0.138, 307105.6] + Cl-CT: [0.1766, 194137.6] + Br-CT: [0.1944, 133051.2] + I-CT: [0.2166, 123846.4] + F-CA: [0.1359, 323004.8] + Cl-CA: [0.1727, 161502.4] + I-CA: [0.2075, 143092.8] + Br-CA: [0.189, 143929.6] + EP-O: [0.02, 502080.0] + EP-OH: [0.02, 502080.0] + EP-OS: [0.02, 502080.0] + EP-N3: [0.02, 502080.0] + EP-NT: [0.02, 502080.0] + EP-NB: [0.02, 502080.0] + EP-NC: [0.02, 502080.0] + EP-S: [0.07, 502080.0] + EP-SH: [0.07, 502080.0] + CI-H1: [0.109, 284512.0] + CI-CT: [0.1526, 259408.0] + OS-CI: [0.141, 267776.0] + OH-CI: [0.141, 267776.0] + C5-H5: [0.108, 307105.6] + C5-N*: [0.1371, 368192.0] + C5-NB: [0.1304, 442667.2] + C-C4: [0.1444, 343088.0] + CA-C4: [0.1433, 357313.6] + C4-C4: [0.135, 459403.2] + C4-CT: [0.151, 265265.6] + C4-HA: [0.108, 307105.6] + C4-H4: [0.108, 307105.6] + C4-N*: [0.1365, 374886.4] + C-2C: [0.1522, 265265.6] + C*-2C: [0.1495, 265265.6] + C8-C8: [0.1526, 259408.0] + C8-CX: [0.1526, 259408.0] + C8-H1: [0.109, 284512.0] + C8-HC: [0.109, 284512.0] + C8-HP: [0.109, 284512.0] + C8-N2: [0.1463, 282001.6] + C8-N3: [0.1471, 307105.6] + CA-2C: [0.151, 265265.6] + CC-2C: [0.1504, 265265.6] + CO-O2: [0.125, 548940.8] + CO-2C: [0.1522, 265265.6] + CT-2C: [0.1526, 259408.0] + CT-3C: [0.1526, 259408.0] + CX-2C: [0.1526, 259408.0] + CX-3C: [0.1526, 259408.0] + H1-2C: [0.109, 284512.0] + H1-3C: [0.109, 284512.0] + HC-2C: [0.109, 284512.0] + HC-3C: [0.109, 284512.0] + OH-2C: [0.141, 267776.0] + OH-3C: [0.141, 267776.0] + S-2C: [0.181, 189953.6] + SH-2C: [0.181, 198321.6] + 2C-2C: [0.1526, 259408.0] + 2C-3C: [0.1526, 259408.0] + angle_energy: + length_unit: nm + energy_unit: kj/mol + 
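# Editorial note: as in bond_energy above, each entry below pairs an equilibrium + # value with a harmonic force constant. Angles appear to be given in degrees with + # force constants in kJ/mol/rad^2 (bond lengths are in nm with constants in + # kJ/mol/nm^2); 'length_unit' is only meaningful for the distance-based terms. +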
parameter_names: + pattern: [bond_angle, force_constant] + parameters: + HW-OW-HW: [104.52, 836.8] + HW-HW-OW: [127.74, 0.0] + C-C-O: [120.0, 669.44] + C-C-OH: [120.0, 669.44] + CA-C-CA: [120.0, 527.184] + CA-C-OH: [120.0, 585.76] + CA-C-OS: [120.0, 585.76] + CC-NA-P: [125.1, 641.825574] + CR-NA-P: [125.1, 641.825574] + NA-P-OP: [102.38, 358.987213] + CB-C-NA: [111.3, 585.76] + CB-C-O: [128.8, 669.44] + CM-C-NA: [114.1, 585.76] + CM-C-O: [125.3, 669.44] + CS-C-NA: [114.1, 585.76] + CS-C-O: [125.3, 669.44] + CT-C-O: [120.4, 669.44] + CX-C-O: [120.4, 669.44] + CT-C-O2: [117.0, 585.76] + CX-C-O2: [117.0, 585.76] + CT-C-N: [116.6, 585.76] + CX-C-N: [116.6, 585.76] + CT-C-CT: [117.0, 527.184] + CT-C-OS: [115.0, 669.44] + CT-C-OH: [110.0, 669.44] + CX-C-OH: [110.0, 669.44] + N*-C-NA: [115.4, 585.76] + N*-C-NC: [118.6, 585.76] + N*-C-O: [120.9, 669.44] + NA-C-O: [120.6, 669.44] + NC-C-O: [122.5, 669.44] + N-C-O: [122.9, 669.44] + O-C-O: [126.0, 669.44] + O-C-OH: [120.0, 669.44] + O-C-OS: [125.0, 669.44] + O2-C-O2: [126.0, 669.44] + H4-C-C: [120.0, 418.4] + H4-C-CM: [115.0, 418.4] + H4-C-CS: [115.0, 418.4] + H4-C-CT: [115.0, 418.4] + H4-C-O: [120.0, 418.4] + H4-C-OH: [120.0, 418.4] + H5-C-N: [120.0, 418.4] + H5-C-O: [119.0, 418.4] + H5-C-OH: [107.0, 418.4] + H5-C-OS: [107.0, 418.4] + C-CA-CA: [120.0, 527.184] + C-CA-HA: [120.0, 418.4] + CA-CA-CA: [120.0, 527.184] + CA-CA-CB: [120.0, 527.184] + CA-CA-CT: [120.0, 585.76] + CA-CA-HA: [120.0, 418.4] + CA-CA-H4: [120.0, 418.4] + CA-CA-OH: [120.0, 585.76] + CA-CA-CN: [120.0, 527.184] + CB-CA-HA: [120.0, 418.4] + CB-CA-H4: [120.0, 418.4] + CB-CA-N2: [123.5, 585.76] + CB-CA-NC: [117.3, 585.76] + CM-CA-N2: [120.1, 585.76] + CM-CA-NC: [121.5, 585.76] + CS-CA-N2: [120.1, 585.76] + CS-CA-NC: [121.5, 585.76] + CN-CA-HA: [120.0, 418.4] + NA-CA-NC: [123.3, 585.76] + N2-CA-NA: [116.0, 585.76] + N2-CA-NC: [119.3, 585.76] + N2-CA-N2: [120.0, 585.76] + F-CA-CA: [121.0, 585.76] + Cl-CA-CA: [118.8, 585.76] + Br-CA-CA: [118.8, 585.76] + I-CA-CA: [118.8, 585.76] + C-CB-CB: [119.2, 527.184] + C-CB-NB: [130.0, 585.76] + CA-CB-CB: [117.3, 527.184] + CA-CB-NB: [132.4, 585.76] + CB-CB-N*: [106.2, 585.76] + CB-CB-NB: [110.4, 585.76] + CB-CB-NC: [127.7, 585.76] + C*-CB-CA: [134.9, 527.184] + C*-CB-CN: [108.8, 527.184] + CA-CB-CN: [116.2, 527.184] + N*-CB-NC: [126.2, 585.76] + CD-CD-CM: [120.0, 527.184] + CD-CD-CS: [120.0, 527.184] + CD-CD-CT: [120.0, 585.76] + CM-CD-CT: [120.0, 585.76] + CS-CD-CT: [120.0, 585.76] + HA-CD-HA: [119.0, 292.88] + HA-CD-CD: [120.0, 418.4] + HA-CD-CM: [120.0, 418.4] + HA-CD-CS: [120.0, 418.4] + H5-CK-N*: [123.05, 418.4] + H5-CK-NB: [123.05, 418.4] + N*-CK-NB: [113.9, 585.76] + H5-CP-N*: [123.05, 418.4] + H5-CP-NB: [123.05, 418.4] + N*-CP-NB: [113.9, 585.76] + C-CM-CM: [120.7, 527.184] + C-CM-CT: [119.7, 585.76] + C-CM-HA: [119.7, 418.4] + C-CM-H4: [119.7, 418.4] + CA-CM-CM: [117.0, 527.184] + CA-CM-HA: [123.3, 418.4] + CA-CM-H4: [123.3, 418.4] + CM-CM-CT: [119.7, 585.76] + CM-CM-HA: [119.7, 418.4] + CM-CM-H4: [119.7, 418.4] + CM-CM-N*: [121.2, 585.76] + CM-CM-OS: [125.0, 669.44] + H4-CM-N*: [119.1, 418.4] + H4-CM-OS: [113.0, 418.4] + HA-CM-HA: [120.0, 292.88] + HA-CM-CD: [120.0, 418.4] + HA-CM-CT: [120.0, 418.4] + C-CS-CS: [120.7, 527.184] + C-CS-CT: [119.7, 585.76] + C-CS-HA: [119.7, 418.4] + C-CS-H4: [119.7, 418.4] + CA-CS-CS: [117.0, 527.184] + CA-CS-HA: [123.3, 418.4] + CA-CS-H4: [123.3, 418.4] + CM-CS-CT: [119.7, 585.76] + CS-CS-HA: [119.7, 418.4] + CS-CS-H4: [119.7, 418.4] + CS-CS-N*: [121.2, 585.76] + CS-CS-OS: [125.0, 669.44] + H4-CS-N*: 
[119.1, 418.4] + H4-CS-OS: [113.0, 418.4] + HA-CS-HA: [120.0, 292.88] + HA-CS-CD: [120.0, 418.4] + HA-CS-CT: [120.0, 418.4] + NC-CQ-NC: [129.1, 585.76] + H5-CQ-NC: [115.45, 418.4] + H1-CT-H1: [109.5, 292.88] + H1-CX-H1: [109.5, 292.88] + H1-CT-N*: [109.5, 418.4] + H1-CT-OH: [109.5, 418.4] + H1-CT-OS: [109.5, 418.4] + H1-CT-CM: [109.5, 418.4] + H1-CT-CS: [109.5, 418.4] + H1-CT-CY: [110.0, 418.4] + H1-CT-CZ: [110.0, 418.4] + H1-CT-N: [109.5, 418.4] + H1-CX-N: [109.5, 418.4] + H1-CT-S: [109.5, 418.4] + H1-CT-SH: [109.5, 418.4] + H1-CT-N2: [109.5, 418.4] + H1-CT-NT: [109.5, 418.4] + H2-CT-H2: [109.5, 292.88] + H2-CT-N*: [109.5, 418.4] + H2-CT-OS: [109.5, 418.4] + HP-CT-HP: [109.5, 292.88] + HP-CX-HP: [109.5, 292.88] + HP-CT-N3: [109.5, 418.4] + HP-CX-N3: [109.5, 418.4] + HC-CT-HC: [109.5, 292.88] + HC-CT-CM: [109.5, 418.4] + HC-CT-CS: [109.5, 418.4] + HC-CT-CD: [109.5, 418.4] + HC-CT-CZ: [110.0, 418.4] + C-CT-H1: [109.5, 418.4] + C-CX-H1: [109.5, 418.4] + C-CT-HP: [109.5, 418.4] + C-CX-HP: [109.5, 418.4] + C-CT-HC: [109.5, 418.4] + C-CT-N: [110.1, 527.184] + C-CX-N: [110.1, 527.184] + C-CT-N3: [111.2, 669.44] + C-CX-N3: [111.2, 669.44] + C-CT-CT: [111.1, 527.184] + C-CT-CX: [111.1, 527.184] + C-CX-CT: [111.1, 527.184] + C-CT-OS: [109.5, 502.08] + CA-CT-HC: [109.5, 418.4] + CC-CT-CT: [113.1, 527.184] + CC-CT-CX: [113.1, 527.184] + CC-CT-HC: [109.5, 418.4] + CM-CT-CT: [111.0, 527.184] + CM-CT-OS: [109.5, 418.4] + CS-CT-CT: [111.0, 527.184] + CS-CT-OS: [109.5, 418.4] + CT-CT-CT: [109.5, 334.72] + CT-CT-CX: [109.5, 334.72] + CT-CT-HC: [109.5, 418.4] + CX-CT-HC: [109.5, 418.4] + CT-CT-H1: [109.5, 418.4] + CT-CX-H1: [109.5, 418.4] + CX-CT-H1: [109.5, 418.4] + CT-CT-H2: [109.5, 418.4] + CT-CT-HP: [109.5, 418.4] + CT-CX-HP: [109.5, 418.4] + CT-CT-N*: [109.5, 418.4] + CT-CT-OH: [109.5, 418.4] + CX-CT-OH: [109.5, 418.4] + CT-CT-OS: [109.5, 418.4] + CT-CT-S: [114.7, 418.4] + CX-CT-S: [114.7, 418.4] + CT-CT-SH: [108.6, 418.4] + CX-CT-SH: [108.6, 418.4] + CT-CT-CA: [114.0, 527.184] + CX-CT-CA: [114.0, 527.184] + CT-CT-N2: [111.2, 669.44] + CT-CT-N: [109.7, 669.44] + CT-CX-N: [109.7, 669.44] + CT-CT-N3: [111.2, 669.44] + CT-CX-N3: [111.2, 669.44] + CT-CT-NT: [111.2, 669.44] + CT-CT-CY: [110.0, 527.184] + CT-CT-CZ: [110.0, 527.184] + C*-CT-CT: [115.6, 527.184] + C*-CT-CX: [115.6, 527.184] + C*-CT-HC: [109.5, 418.4] + OS-CT-OS: [101.0, 1338.88] + OS-CT-CY: [110.0, 418.4] + OS-CT-CZ: [110.0, 418.4] + OS-CT-N*: [109.5, 418.4] + F-CT-F: [109.1, 644.336] + F-CT-H1: [109.5, 418.4] + F-CT-CT: [109.0, 418.4] + F-CT-H2: [109.5, 418.4] + Cl-CT-CT: [108.5, 418.4] + Cl-CT-H1: [108.5, 418.4] + Br-CT-CT: [108.0, 418.4] + Br-CT-H1: [106.5, 418.4] + I-CT-CT: [106.0, 418.4] + CT-CC-NA: [120.0, 585.76] + CT-CC-CV: [120.0, 585.76] + CT-CC-NB: [120.0, 585.76] + CV-CC-NA: [120.0, 585.76] + CW-CC-NA: [120.0, 585.76] + CW-CC-NB: [120.0, 585.76] + CT-CC-CW: [120.0, 585.76] + H5-CR-NA: [120.0, 418.4] + H5-CR-NB: [120.0, 418.4] + NA-CR-NA: [120.0, 585.76] + NA-CR-NB: [120.0, 585.76] + CC-CV-H4: [120.0, 418.4] + CC-CV-NB: [120.0, 585.76] + H4-CV-NB: [120.0, 418.4] + CC-CW-H4: [120.0, 418.4] + CC-CW-NA: [120.0, 585.76] + C*-CW-H4: [120.0, 418.4] + C*-CW-NA: [108.7, 585.76] + H4-CW-NA: [120.0, 418.4] + CB-C*-CT: [128.6, 585.76] + CB-C*-CW: [106.4, 527.184] + CT-C*-CW: [125.0, 585.76] + CA-CN-CB: [122.7, 527.184] + CA-CN-NA: [132.8, 585.76] + CB-CN-NA: [104.4, 585.76] + CT-CY-NY: [180.0, 669.44] + CT-CZ-CZ: [180.0, 669.44] + CZ-CZ-HZ: [180.0, 418.4] + C-N-CT: [121.9, 418.4] + C-N-CX: [121.9, 418.4] + C-N-H: [120.0, 418.4] + CT-N-H: 
[118.04, 418.4] + CX-N-H: [118.04, 418.4] + CT-N-CT: [118.0, 418.4] + CT-N-CX: [118.0, 418.4] + H-N-H: [120.0, 292.88] + C-N*-CM: [121.6, 585.76] + C-N*-CS: [121.6, 585.76] + C-N*-CT: [117.6, 585.76] + C-N*-H: [119.2, 418.4] + CB-N*-CK: [105.4, 585.76] + CB-N*-CP: [105.4, 585.76] + CB-N*-CT: [125.8, 585.76] + CB-N*-H: [125.8, 418.4] + CK-N*-CT: [128.8, 585.76] + CK-N*-H: [128.8, 418.4] + CP-N*-CT: [128.8, 585.76] + CP-N*-H: [128.8, 418.4] + CM-N*-CT: [121.2, 585.76] + CM-N*-H: [121.2, 418.4] + CS-N*-CT: [121.2, 585.76] + CS-N*-H: [121.2, 418.4] + CA-N2-H: [120.0, 418.4] + CA-N2-CT: [123.2, 418.4] + CT-N2-H: [118.4, 418.4] + H-N2-H: [120.0, 292.88] + CT-N3-H: [109.5, 418.4] + CX-N3-H: [109.5, 418.4] + CT-N3-CT: [109.5, 418.4] + CT-N3-CX: [109.5, 418.4] + H-N3-H: [109.5, 292.88] + CT-NT-H: [109.5, 418.4] + CT-NT-CT: [109.5, 418.4] + H-NT-H: [109.5, 292.88] + C-NA-C: [126.4, 585.76] + C-NA-CA: [125.2, 585.76] + C-NA-H: [116.8, 418.4] + CA-NA-H: [118.0, 418.4] + CC-NA-CR: [120.0, 585.76] + CC-NA-H: [120.0, 418.4] + CR-NA-CW: [120.0, 585.76] + CR-NA-H: [120.0, 418.4] + CW-NA-H: [120.0, 418.4] + CN-NA-CW: [111.6, 585.76] + CN-NA-H: [123.1, 418.4] + CB-NB-CK: [103.8, 585.76] + CB-NB-CP: [103.8, 585.76] + CC-NB-CR: [117.0, 585.76] + CR-NB-CV: [117.0, 585.76] + C-NC-CA: [120.5, 585.76] + CA-NC-CB: [112.2, 585.76] + CA-NC-CQ: [118.6, 585.76] + CB-NC-CQ: [111.0, 585.76] + C-OH-HO: [113.0, 418.4] + CA-OH-HO: [113.0, 418.4] + CT-OH-HO: [108.5, 460.24] + HO-OH-P: [108.5, 376.56] + C-OS-CT: [117.0, 502.08] + CM-OS-CT: [117.0, 502.08] + CS-OS-CT: [117.0, 502.08] + CT-OS-CT: [109.5, 502.08] + CT-OS-P: [120.5, 836.8] + C-OS-P: [120.5, 836.8] + P-OS-P: [120.5, 836.8] + O2-P-OH: [108.23, 376.56] + O2-P-O2: [119.9, 1171.52] + OP-P-OP: [119.9, 1171.52] + OP-P-OS: [108.23, 836.8] + O2-P-OS: [108.23, 836.8] + OH-P-OS: [102.6, 376.56] + OS-P-OS: [102.6, 376.56] + CT-S-CT: [98.9, 518.816] + CT-S-S: [103.7, 569.024] + CT-SH-HS: [96.0, 359.824] + HS-SH-HS: [92.07, 292.88] + CB-NB-EP: [126.0, 1255.2] + CC-NB-EP: [126.0, 1255.2] + CK-NB-EP: [126.0, 1255.2] + CP-NB-EP: [126.0, 1255.2] + CR-NB-EP: [126.0, 1255.2] + CV-NB-EP: [126.0, 1255.2] + C-NC-EP: [120.0, 1255.2] + CA-NC-EP: [120.0, 1255.2] + CB-NC-EP: [120.0, 1255.2] + CQ-NC-EP: [120.0, 1255.2] + CT-N3-EP: [109.5, 1255.2] + H-N3-EP: [109.5, 1255.2] + CT-NT-EP: [109.5, 1255.2] + H-NT-EP: [109.5, 1255.2] + C-O-EP: [120.0, 1255.2] + EP-O-EP: [120.0, 1255.2] + C-OH-EP: [120.0, 1255.2] + CT-OH-EP: [109.5, 1255.2] + HO-OH-EP: [109.5, 1255.2] + EP-OH-EP: [109.5, 1255.2] + C-OS-EP: [109.5, 1255.2] + CM-OS-EP: [109.5, 1255.2] + CS-OS-EP: [109.5, 1255.2] + CT-OS-EP: [109.5, 1255.2] + EP-OS-EP: [109.5, 1255.2] + CT-S-EP: [90.0, 1255.2] + CT-SH-EP: [90.0, 1255.2] + P-OS-EP: [109.5, 1255.2] + EP-S-EP: [180.0, 1255.2] + EP-SH-EP: [180.0, 1255.2] + HS-SH-EP: [90.0, 1255.2] + H1-CI-CT: [109.5, 418.4] + H1-CI-H1: [109.5, 292.88] + CI-CT-H1: [109.5, 418.4] + CI-CT-OS: [109.5, 418.4] + CI-CT-CT: [109.5, 334.72] + OS-CI-H1: [109.5, 418.4] + OS-CI-CT: [109.5, 418.4] + P-OS-CI: [120.5, 836.8] + OH-CI-H1: [109.5, 418.4] + OH-CI-CT: [109.5, 418.4] + HO-OH-CI: [108.5, 460.24] + H5-C5-N*: [123.05, 418.4] + H5-C5-NB: [123.05, 418.4] + N*-C5-NB: [113.9, 585.76] + CB-N*-C5: [105.4, 585.76] + C5-N*-CT: [128.8, 585.76] + CB-NB-C5: [103.8, 585.76] + C4-C-NA: [114.1, 585.76] + C4-C-O: [125.3, 669.44] + C4-CA-N2: [120.1, 585.76] + C4-CA-NC: [121.5, 585.76] + C-C4-C4: [120.7, 527.184] + C-C4-CT: [119.7, 585.76] + C-C4-HA: [119.7, 418.4] + C-C4-H4: [119.7, 418.4] + CA-C4-C4: [117.0, 527.184] + 
CA-C4-HA: [123.3, 418.4] + CA-C4-H4: [123.3, 418.4] + C4-C4-CT: [119.7, 585.76] + C4-C4-HA: [119.7, 418.4] + C4-C4-H4: [119.7, 418.4] + C4-C4-N*: [121.2, 585.76] + H4-C4-N*: [119.1, 418.4] + H1-CT-C4: [109.5, 418.4] + HC-CT-C4: [109.5, 418.4] + C-N*-C4: [121.6, 585.76] + C4-N*-CT: [121.2, 585.76] + EP-S-S: [96.7, 1255.2] + N-C-2C: [116.6, 585.76] + O-C-2C: [120.4, 669.44] + OH-C-2C: [110.0, 669.44] + CB-C*-2C: [128.6, 585.76] + CW-C*-2C: [125.0, 585.76] + C8-C8-C8: [109.5, 334.72] + C8-C8-CX: [109.5, 334.72] + C8-C8-H1: [109.5, 418.4] + C8-C8-HC: [109.5, 418.4] + C8-C8-HP: [109.5, 418.4] + C8-C8-N2: [111.2, 669.44] + C8-C8-N3: [111.2, 669.44] + CX-C8-HC: [109.5, 418.4] + H1-C8-H1: [109.5, 292.88] + H1-C8-N2: [109.5, 418.4] + HC-C8-HC: [109.5, 292.88] + HP-C8-HP: [109.5, 292.88] + HP-C8-N3: [109.5, 418.4] + CA-CA-2C: [120.0, 585.76] + CV-CC-2C: [120.0, 585.76] + CW-CC-2C: [120.0, 585.76] + NA-CC-2C: [120.0, 585.76] + NB-CC-2C: [120.0, 585.76] + O2-CO-O2: [126.0, 669.44] + O2-CO-2C: [117.0, 585.76] + HC-CT-2C: [109.5, 418.4] + HC-CT-3C: [109.5, 418.4] + C-CX-C8: [111.1, 527.184] + C-CX-2C: [111.1, 527.184] + C-CX-3C: [111.1, 527.184] + C8-CX-H1: [109.5, 418.4] + C8-CX-N: [109.7, 669.44] + C8-CX-N3: [111.2, 669.44] + H1-CX-2C: [109.5, 418.4] + H1-CX-3C: [109.5, 418.4] + HP-CX-C8: [109.5, 418.4] + HP-CX-2C: [109.5, 418.4] + HP-CX-3C: [109.5, 418.4] + N-CX-2C: [109.7, 669.44] + N-CX-3C: [109.7, 669.44] + N3-CX-2C: [111.2, 669.44] + N3-CX-3C: [111.2, 669.44] + C8-N2-CA: [123.2, 418.4] + C8-N2-H: [118.4, 418.4] + C8-N3-H: [109.5, 418.4] + HO-OH-2C: [108.5, 460.24] + HO-OH-3C: [108.5, 460.24] + CT-S-2C: [98.9, 518.816] + 2C-S-S: [103.7, 569.024] + HS-SH-2C: [96.0, 359.824] + C-2C-CX: [111.1, 527.184] + C-2C-HC: [109.5, 418.4] + C-2C-2C: [111.1, 527.184] + C*-2C-CX: [115.6, 527.184] + C*-2C-HC: [109.5, 418.4] + CA-2C-CX: [114.0, 527.184] + CA-2C-HC: [109.5, 418.4] + CC-2C-CX: [113.1, 527.184] + CC-2C-HC: [109.5, 418.4] + CO-2C-CX: [111.1, 527.184] + CO-2C-HC: [109.5, 418.4] + CO-2C-2C: [111.1, 527.184] + CT-2C-HC: [109.5, 418.4] + CT-2C-3C: [109.5, 334.72] + CX-2C-H1: [109.5, 418.4] + CX-2C-HC: [109.5, 418.4] + CX-2C-OH: [109.5, 418.4] + CX-2C-S: [114.7, 418.4] + CX-2C-SH: [108.6, 418.4] + CX-2C-2C: [109.5, 334.72] + CX-2C-3C: [109.5, 334.72] + H1-2C-H1: [109.5, 292.88] + H1-2C-OH: [109.5, 418.4] + H1-2C-S: [109.5, 418.4] + H1-2C-SH: [109.5, 418.4] + H1-2C-2C: [109.5, 418.4] + HC-2C-HC: [109.5, 292.88] + HC-2C-2C: [109.5, 418.4] + HC-2C-3C: [109.5, 418.4] + S-2C-2C: [114.7, 418.4] + CT-3C-CT: [109.5, 334.72] + CT-3C-CX: [109.5, 334.72] + CT-3C-H1: [109.5, 418.4] + CT-3C-HC: [109.5, 418.4] + CT-3C-OH: [109.5, 418.4] + CT-3C-2C: [109.5, 334.72] + CX-3C-H1: [109.5, 418.4] + CX-3C-HC: [109.5, 418.4] + CX-3C-OH: [109.5, 418.4] + CX-3C-2C: [109.5, 334.72] + H1-3C-OH: [109.5, 418.4] + HC-3C-2C: [109.5, 418.4] + dihedral_energy: + length_unit: nm + energy_unit: kj/mol + parameter_names: + pattern: + - [phase, force_constant, periodicity] + parameters: + ?-C-C-?: + - [180.0, 30.334, 2] + ?-C-CA-?: + - [180.0, 30.334, 2] + ?-C-CB-?: + - [180.0, 25.104, 2] + ?-C-CM-?: + - [180.0, 18.2004, 2] + ?-C-CS-?: + - [180.0, 18.2004, 2] + ?-C-CT-?: + - [0.0, 0.0, 0] + ?-C-CX-?: + - [0.0, 0.0, 0] + ?-C-N-?: + - [180.0, 20.92, 2] + ?-C-N*-?: + - [180.0, 12.1336, 2] + ?-C-NA-?: + - [180.0, 11.2968, 2] + ?-C-NC-?: + - [180.0, 33.472, 2] + ?-C-O-?: + - [180.0, 23.4304, 2] + ?-C-OH-?: + - [180.0, 19.2464, 2] + ?-C-OS-?: + - [180.0, 22.5936, 2] + ?-CA-CA-?: + - [180.0, 30.334, 2] + ?-CA-CB-?: + - [180.0, 29.288, 2] + 
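# Editorial note: each torsion key maps to one or more Fourier terms of the + # pattern declared above, [phase (degrees), force_constant (kJ/mol), periodicity]. + # '?' is a wildcard matching any atom type, and entries like [0.0, 0.0, 0] are + # explicit null torsions, presumably placeholders so that every matched pattern + # has a parameter set. +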
?-CA-CM-?: + - [180.0, 21.3384, 2] + ?-CA-CS-?: + - [180.0, 21.3384, 2] + ?-CA-CN-?: + - [180.0, 30.334, 2] + ?-CA-CT-?: + - [0.0, 0.0, 0] + ?-CA-N2-?: + - [180.0, 20.083201, 2] + ?-CA-NA-?: + - [180.0, 12.552, 2] + ?-CA-NC-?: + - [180.0, 40.166402, 2] + ?-CA-OH-?: + - [180.0, 7.5312, 2] + ?-CB-CB-?: + - [180.0, 45.605598, 2] + ?-CB-CN-?: + - [180.0, 25.104, 2] + ?-CB-N*-?: + - [180.0, 13.8072, 2] + ?-CB-NB-?: + - [180.0, 21.3384, 2] + ?-CB-NC-?: + - [180.0, 34.727201, 2] + ?-CC-CT-?: + - [0.0, 0.0, 0] + ?-CC-CV-?: + - [180.0, 43.095201, 2] + ?-CC-CW-?: + - [180.0, 44.978, 2] + ?-CC-NA-?: + - [180.0, 11.7152, 2] + ?-CC-NB-?: + - [180.0, 20.083201, 2] + ?-CD-CD-?: + - [180.0, 8.368, 2] + ?-CD-CT-?: + - [0.0, 0.0, 0] + ?-CD-CM-?: + - [180.0, 55.647201, 2] + ?-CD-CS-?: + - [180.0, 55.647201, 2] + ?-CK-N*-?: + - [180.0, 14.2256, 2] + ?-CK-NB-?: + - [180.0, 83.68, 2] + ?-CP-N*-?: + - [180.0, 14.2256, 2] + ?-CP-NB-?: + - [180.0, 83.68, 2] + ?-CM-CM-?: + - [180.0, 55.647201, 2] + ?-CM-CT-?: + - [0.0, 0.0, 0] + ?-CM-N*-?: + - [180.0, 15.4808, 2] + ?-CM-OS-?: + - [180.0, 8.7864, 2] + ?-CS-CS-?: + - [180.0, 55.647201, 2] + ?-CS-CT-?: + - [0.0, 0.0, 0] + ?-CS-N*-?: + - [180.0, 15.4808, 2] + ?-CS-OS-?: + - [180.0, 8.7864, 2] + ?-CN-NA-?: + - [180.0, 12.7612, 2] + ?-CQ-NC-?: + - [180.0, 56.902402, 2] + ?-CT-CT-?: + - [0.0, 1.301689, 3] + ?-CT-CX-?: + - [0.0, 1.301689, 3] + ?-CT-CY-?: + - [0.0, 0.0, 0] + ?-CT-CZ-?: + - [0.0, 0.0, 0] + ?-CT-N-?: + - [0.0, 0.0, 0] + ?-CX-N-?: + - [0.0, 0.0, 0] + ?-CT-N*-?: + - [0.0, 0.0, 0] + ?-CT-N2-?: + - [0.0, 0.0, 0] + ?-CT-NT-?: + - [0.0, 2.5104, 3] + ?-CT-N3-?: + - [0.0, 1.301689, 3] + ?-CX-N3-?: + - [0.0, 1.301689, 3] + ?-CT-OH-?: + - [0.0, 1.394667, 3] + ?-CT-OS-?: + - [0.0, 3.207733, 3] + ?-CT-S-?: + - [0.0, 2.789333, 3] + ?-CT-SH-?: + - [0.0, 2.092, 3] + ?-C*-CB-?: + - [180.0, 14.0164, 2] + ?-C*-CT-?: + - [0.0, 0.0, 0] + ?-C*-CW-?: + - [180.0, 54.601201, 2] + ?-CR-NA-?: + - [180.0, 19.4556, 2] + ?-CR-NB-?: + - [180.0, 41.84, 2] + ?-CV-NB-?: + - [180.0, 20.083201, 2] + ?-CW-NA-?: + - [180.0, 12.552, 2] + ?-OH-P-?: + - [0.0, 2.092, 3] + ?-OS-P-?: + - [0.0, 2.092, 3] + ?-CI-OS-?: + - [0.0, 3.207733, 3] + ?-CI-OH-?: + - [0.0, 1.394667, 3] + ?-CI-CT-?: + - [0.0, 1.301689, 3] + ?-C5-N*-?: + - [180.0, 14.2256, 2] + ?-C5-NB-?: + - [180.0, 83.68, 2] + ?-C-C4-?: + - [180.0, 18.2004, 2] + ?-CA-C4-?: + - [180.0, 21.3384, 2] + ?-C4-C4-?: + - [180.0, 55.647201, 2] + ?-C4-CT-?: + - [0.0, 0.0, 0] + ?-C4-N*-?: + - [180.0, 15.4808, 2] + CT-OS-CT-CI: + - [180.0, 0.8368, 2] + - [0.0, 3.204944, 3] + H1-CI-CT-OS: + - [0.0, 2.092, 1] + H1-CI-CT-OH: + - [0.0, 2.092, 1] + H1-CT-CI-OS: + - [0.0, 2.092, 1] + H1-CT-CI-OH: + - [0.0, 2.092, 1] + CI-CT-CT-CT: + - [180.0, 1.6736, 1] + - [180.0, 2.092, 2] + - [0.0, 1.50624, 3] + OP-P-OS-CA: + - [0.0, 0.0, 0] + CC-NA-P-OP: + - [0.0, 2.00832, 3] + CR-NA-P-OP: + - [0.0, 0.0, 0] + OS-P-OS-CI: + - [357.2475, 2.969452, 3] + - [351.9596, 10.514651, 2] + - [31.7951, 1.549595, 1] + OH-P-OS-CI: + - [357.2475, 2.969452, 3] + - [351.9596, 10.514651, 2] + - [31.7951, 1.549595, 1] + CT-CT-CI-OS: + - [348.0953, 8.056961, 3] + - [295.6328, 0.77071, 2] + - [190.9765, 9.857839, 1] + CT-CT-CI-OH: + - [348.0953, 8.056961, 3] + - [295.6328, 0.77071, 2] + - [190.9765, 9.857839, 1] + HC-CT-C4-C4: + - [0.0, 9.6232, 1] + - [180.0, 3.17984, 3] + C4-C4-C-O: + - [0.0, 2.5104, 3] + - [180.0, 18.2004, 2] + OS-CT-N*-C5: + - [19.0921, 2.587135, 4] + - [171.5787, 3.828695, 3] + - [15.636, 8.987483, 2] + - [68.7902, 8.080225, 1] + OS-CT-N*-CP: + - [3.9746, 2.142375, 4] + - 
[168.6503, 3.704765, 3] + - [6.2286, 8.915769, 2] + - [74.7558, 5.900277, 1] + OS-CT-N*-C4: + - [32.159, 2.596841, 4] + - [185.8774, 7.844749, 3] + - [16.4766, 13.678249, 2] + - [146.9892, 10.251302, 1] + OS-CT-N*-CS: + - [16.0016, 2.941185, 4] + - [179.3474, 4.865992, 3] + - [16.7648, 14.633624, 2] + - [149.8583, 8.578372, 1] + C-N-CX-C: + - [0.0, 2.25936, 2] + - [0.0, 3.51456, 3] + N-CX-C-N: + - [180.0, 3.7656, 1] + - [180.0, 13.22144, 2] + - [180.0, 4.6024, 3] + CT-CT-N-C: + - [0.0, 16.736, 1] + - [0.0, 16.736, 2] + - [0.0, 3.3472, 3] + CT-CX-N-C: + - [0.0, 16.736, 1] + - [0.0, 15.0624, 2] + - [0.0, 6.6944, 3] + CT-CT-C-N: + - [0.0, 1.6736, 1] + - [0.0, 1.6736, 2] + - [0.0, 3.3472, 3] + CT-CX-C-N: + - [0.0, 1.6736, 1] + - [0.0, 1.6736, 2] + - [0.0, 3.3472, 3] + CX-CT-C-N: + - [0.0, 1.6736, 1] + - [0.0, 1.6736, 2] + - [0.0, 3.3472, 3] + H-N-C-O: + - [0.0, 16.736, 1] + - [180.0, 20.92, 2] + CT-S-S-CT: + - [0.0, 5.0208, 3] + - [0.0, 29.288, 2] + OH-P-OS-CT: + - [0.0, 10.0416, 2] + - [0.0, 2.092, 3] + OS-P-OS-CT: + - [0.0, 10.0416, 2] + - [0.0, 2.092, 3] + H1-CT-C-O: + - [180.0, 0.66944, 3] + - [0.0, 6.6944, 1] + H1-CX-C-O: + - [180.0, 0.66944, 3] + - [0.0, 6.6944, 1] + HC-CT-C-O: + - [180.0, 0.66944, 3] + - [0.0, 6.6944, 1] + HC-CT-CT-HC: + - [0.0, 1.2552, 3] + HC-CT-CT-CT: + - [0.0, 1.33888, 3] + HC-CT-CT-CX: + - [0.0, 1.33888, 3] + HC-CT-CM-CM: + - [0.0, 9.6232, 1] + - [180.0, 3.17984, 3] + HC-CT-CS-CS: + - [0.0, 9.6232, 1] + - [180.0, 3.17984, 3] + HO-OH-CT-CT: + - [0.0, 2.092, 1] + - [0.0, 1.33888, 3] + HO-OH-CT-CX: + - [0.0, 2.092, 1] + - [0.0, 1.33888, 3] + HO-OH-C-O: + - [0.0, 15.8992, 1] + - [180.0, 19.2464, 2] + CM-CM-C-O: + - [0.0, 2.5104, 3] + - [180.0, 18.2004, 2] + CT-CM-CM-CT: + - [180.0, 15.8992, 1] + - [180.0, 55.647201, 2] + CS-CS-C-O: + - [0.0, 2.5104, 3] + - [180.0, 18.2004, 2] + CT-CS-CS-CT: + - [180.0, 15.8992, 1] + - [180.0, 55.647201, 2] + CT-CT-CT-CT: + - [180.0, 1.6736, 1] + - [180.0, 2.092, 2] + - [0.0, 1.50624, 3] + CX-CT-CT-CT: + - [180.0, 1.6736, 1] + - [180.0, 2.092, 2] + - [0.0, 1.50624, 3] + CT-CT-NT-CT: + - [180.0, 4.01664, 2] + - [0.0, 2.5104, 3] + CT-CT-OS-CT: + - [180.0, 0.8368, 2] + - [0.0, 3.204944, 3] + CT-CT-OS-C: + - [180.0, 6.6944, 1] + - [0.0, 3.204944, 3] + CT-OS-CT-OS: + - [180.0, 11.2968, 1] + - [180.0, 7.1128, 2] + - [0.0, 0.8368, 3] + CT-OS-CT-N*: + - [0.0, 5.4392, 2] + - [0.0, 3.204944, 3] + CT-CZ-CZ-HZ: + - [0.0, 0.0, 0] + O-C-OS-CT: + - [180.0, 11.7152, 1] + - [180.0, 22.5936, 2] + OS-CT-N*-CK: + - [0.0, 20.92, 1] + OS-CT-N*-CM: + - [0.0, 20.92, 1] + OS-CT-CT-OS: + - [0.0, 9.8324, 2] + - [0.0, 1.204992, 3] + OS-CT-CT-OH: + - [0.0, 9.8324, 2] + - [0.0, 1.204992, 3] + OH-CT-CT-OH: + - [0.0, 9.8324, 2] + - [0.0, 1.204992, 3] + F-CT-CT-F: + - [180.0, 10.0416, 1] + Cl-CT-CT-Cl: + - [180.0, 3.7656, 1] + Br-CT-CT-Br: + - [0.0, 0.0, 0] + H1-CT-CT-OS: + - [0.0, 2.092, 1] + H1-CT-CT-OH: + - [0.0, 2.092, 1] + H1-CX-CT-OH: + - [0.0, 2.092, 1] + H1-CT-CT-F: + - [0.0, 1.58992, 1] + H1-CT-CT-Cl: + - [0.0, 2.092, 1] + H1-CT-CT-Br: + - [0.0, 4.6024, 1] + HC-CT-CT-OS: + - [0.0, 2.092, 1] + HC-CT-CT-OH: + - [0.0, 2.092, 1] + HC-CT-CT-F: + - [0.0, 1.58992, 1] + HC-CT-CT-Cl: + - [0.0, 2.092, 1] + HC-CT-CT-Br: + - [0.0, 4.6024, 1] + H1-CT-NT-EP: + - [0.0, 0.0, 0] + CT-CT-NT-EP: + - [0.0, 0.0, 0] + CT-C-N-EP: + - [0.0, 0.0, 0] + O-C-N-EP: + - [0.0, 0.0, 0] + H1-CT-OH-EP: + - [0.0, 0.0, 0] + CT-CT-OH-EP: + - [0.0, 0.0, 0] + H1-CT-OS-EP: + - [0.0, 0.0, 0] + H2-CT-OS-EP: + - [0.0, 0.0, 0] + CT-CT-OS-EP: + - [0.0, 0.0, 0] + CM-CM-OS-EP: + - [0.0, 0.0, 0] + 
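# Editorial note: EP appears to be the AMBER 'extra point' (a massless lone-pair + # virtual site); its torsion terms here are all null, presumably because such + # sites carry no torsional energy of their own. +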
HA-CM-OS-EP: + - [0.0, 0.0, 0] + H4-CM-OS-EP: + - [0.0, 0.0, 0] + CS-CS-OS-EP: + - [0.0, 0.0, 0] + HA-CS-OS-EP: + - [0.0, 0.0, 0] + H4-CS-OS-EP: + - [0.0, 0.0, 0] + N-CT-CT-OH: + - [0.0, 1.305408, 3] + - [0.0, 12.46832, 2] + EP-S-S-CT: + - [0.0, 0.0, 0] + EP-S-S-EP: + - [0.0, 0.0, 0] + C8-CX-N-C: + - [0.0, 16.736, 1] + - [0.0, 15.0624, 2] + - [0.0, 6.6944, 3] + 2C-CX-N-C: + - [0.0, 16.736, 1] + - [0.0, 15.0624, 2] + - [0.0, 6.6944, 3] + 3C-CX-N-C: + - [0.0, 16.736, 1] + - [0.0, 15.0624, 2] + - [0.0, 6.6944, 3] + N-C-CX-C8: + - [0.0, 1.6736, 1] + - [0.0, 1.6736, 2] + - [0.0, 3.3472, 3] + N-C-CX-2C: + - [0.0, 1.6736, 1] + - [0.0, 1.6736, 2] + - [0.0, 3.3472, 3] + N-C-CX-3C: + - [0.0, 1.6736, 1] + - [0.0, 1.6736, 2] + - [0.0, 3.3472, 3] + N-C-2C-HC: + - [0.0, 0.0, 0] + O-C-2C-HC: + - [0.0, 6.6944, 1] + - [180.0, 0.66944, 3] + OH-C-2C-HC: + - [0.0, 0.0, 0] + CB-C*-2C-HC: + - [0.0, 0.0, 0] + CW-C*-2C-HC: + - [0.0, 0.0, 0] + ?-C8-C8-?: + - [0.0, 1.301689, 3] + C8-C8-C8-HC: + - [0.0, 1.33888, 3] + CX-C8-C8-HC: + - [0.0, 1.33888, 3] + HC-C8-C8-HC: + - [0.0, 1.2552, 3] + ?-C8-CX-?: + - [0.0, 1.301689, 3] + ?-C8-N2-?: + - [0.0, 0.0, 0] + ?-C8-N3-?: + - [0.0, 1.301689, 3] + ?-CA-2C-?: + - [0.0, 0.0, 0] + 2C-CC-CV-H4: + - [180.0, 43.095201, 2] + 2C-CC-CV-NB: + - [180.0, 43.095201, 2] + 2C-CC-CW-H4: + - [180.0, 44.978, 2] + 2C-CC-CW-NA: + - [180.0, 44.978, 2] + 2C-CC-NA-CR: + - [180.0, 11.7152, 2] + 2C-CC-NA-H: + - [180.0, 11.7152, 2] + 2C-CC-NB-CR: + - [180.0, 20.083201, 2] + CV-CC-2C-HC: + - [0.0, 0.0, 0] + CW-CC-2C-HC: + - [0.0, 0.0, 0] + NA-CC-2C-HC: + - [0.0, 0.0, 0] + NB-CC-2C-HC: + - [0.0, 0.0, 0] + O2-CO-2C-HC: + - [0.0, 0.0, 0] + H1-CT-S-2C: + - [0.0, 2.789333, 3] + HC-CT-2C-HC: + - [0.0, 1.2552, 3] + HC-CT-2C-3C: + - [0.0, 1.33888, 3] + HC-CT-3C-CT: + - [0.0, 1.33888, 3] + HC-CT-3C-CX: + - [0.0, 1.33888, 3] + HC-CT-3C-H1: + - [0.0, 1.301689, 3] + HC-CT-3C-HC: + - [0.0, 1.2552, 3] + HC-CT-3C-OH: + - [0.0, 2.092, 1] + HC-CT-3C-2C: + - [0.0, 1.33888, 3] + ?-CX-2C-?: + - [0.0, 1.301689, 3] + H1-CX-2C-OH: + - [0.0, 2.092, 1] + ?-CX-3C-?: + - [0.0, 1.301689, 3] + H1-CX-3C-OH: + - [0.0, 2.092, 1] + HO-OH-2C-H1: + - [0.0, 1.394667, 3] + HO-OH-3C-H1: + - [0.0, 1.394667, 3] + EP-S-S-2C: + - [0.0, 0.0, 0] + ?-S-2C-?: + - [0.0, 2.789333, 3] + ?-SH-2C-?: + - [0.0, 2.092, 3] + ?-2C-2C-?: + - [0.0, 1.301689, 3] + CX-2C-2C-HC: + - [0.0, 1.33888, 3] + HC-2C-2C-HC: + - [0.0, 1.2552, 3] + CT-2C-3C-HC: + - [0.0, 1.33888, 3] + CX-2C-3C-HC: + - [0.0, 1.33888, 3] + HC-2C-3C-CT: + - [0.0, 1.33888, 3] + HC-2C-3C-CX: + - [0.0, 1.33888, 3] + HC-2C-3C-HC: + - [0.0, 1.2552, 3] + C-CX-C8-C8: + - [0.0, 1.301689, 3] + N-CX-C8-C8: + - [0.0, 1.301689, 3] + N3-CX-C8-C8: + - [0.0, 1.301689, 3] + CX-C8-C8-C8: + - [180.0, 1.6736, 1] + - [180.0, 2.092, 2] + - [0.0, 1.50624, 3] + C8-C8-C8-N2: + - [0.0, 1.301689, 3] + C8-C8-N2-CA: + - [0.0, 0.0, 0] + C8-C8-C8-C8: + - [180.0, 1.6736, 1] + - [180.0, 2.092, 2] + - [0.0, 1.50624, 3] + C8-C8-C8-N3: + - [0.0, 1.301689, 3] + C8-N2-CA-N2: + - [180.0, 20.083201, 2] + H-N2-CA-N2: + - [180.0, 20.083201, 2] + C8-C8-N3-H: + - [0.0, 1.301689, 3] + C-CX-2C-SH: + - [180.0, 2.250992, 1] + - [180.0, 2.820016, 2] + - [0.0, 2.100368, 3] + - [0.0, 0.6276, 4] + N-CX-2C-SH: + - [0.0, 1.288672, 1] + - [180.0, 4.066848, 2] + - [0.0, 2.100368, 3] + - [0.0, 0.276144, 4] + N3-CX-2C-SH: + - [0.0, 1.288672, 1] + - [180.0, 4.066848, 2] + - [0.0, 2.100368, 3] + - [0.0, 0.276144, 4] + CX-2C-SH-HS: + - [0.0, 0.769856, 1] + - [0.0, 5.121216, 2] + - [0.0, 2.108736, 3] + - [0.0, 0.25104, 4] + C-CX-2C-CO: + - [0.0, 
3.548032, 1] + - [180.0, 3.840912, 2] + - [0.0, 0.485344, 3] + - [0.0, 1.288672, 4] + N-CX-2C-CO: + - [180.0, 18.024672, 1] + - [180.0, 5.414096, 2] + - [0.0, 0.485344, 3] + - [0.0, 0.744752, 4] + N3-CX-2C-CO: + - [180.0, 18.024672, 1] + - [180.0, 5.414096, 2] + - [0.0, 0.485344, 3] + - [0.0, 0.744752, 4] + CX-2C-CO-O2: + - [180.0, 6.434992, 2] + - [180.0, 0.259408, 4] + C-CX-2C-C: + - [0.0, 8.752928, 1] + - [180.0, 2.535504, 2] + - [0.0, 0.276144, 3] + - [0.0, 0.895376, 4] + N-CX-2C-C: + - [180.0, 5.757184, 1] + - [180.0, 2.485296, 2] + - [0.0, 0.276144, 3] + - [0.0, 0.493712, 4] + N3-CX-2C-C: + - [180.0, 5.757184, 1] + - [180.0, 2.485296, 2] + - [0.0, 0.276144, 3] + - [0.0, 0.493712, 4] + CX-2C-C-O: + - [0.0, 0.0, 0] + CX-2C-C-OH: + - [180.0, 10.033232, 1] + - [180.0, 4.8116, 2] + - [0.0, 0.066944, 3] + - [180.0, 1.665232, 4] + 2C-C-OH-HO: + - [180.0, 3.748864, 1] + - [180.0, 22.643809, 2] + - [0.0, 4.008272, 3] + - [0.0, 0.945584, 4] + CX-2C-C-N: + - [180.0, 6.928704, 1] + - [180.0, 4.05848, 2] + - [180.0, 2.518768, 3] + - [0.0, 0.066944, 4] + C-CX-CT-CC: + - [180.0, 1.196624, 1] + - [180.0, 2.041792, 2] + - [0.0, 1.832592, 3] + - [0.0, 0.2092, 4] + N-CX-CT-CC: + - [180.0, 2.560608, 1] + - [180.0, 1.849328, 2] + - [0.0, 1.832592, 3] + - [0.0, 0.744752, 4] + N3-CX-CT-CC: + - [180.0, 2.560608, 1] + - [180.0, 1.849328, 2] + - [0.0, 1.832592, 3] + - [0.0, 0.744752, 4] + CX-CT-CC-NA: + - [180.0, 1.33888, 1] + - [180.0, 3.280256, 2] + - [0.0, 5.740448, 3] + - [180.0, 0.309616, 4] + CX-CT-CC-CV: + - [180.0, 5.640032, 1] + - [0.0, 6.276, 2] + - [180.0, 1.020896, 3] + - [180.0, 0.08368, 4] + CX-CT-CC-NB: + - [0.0, 5.77392, 1] + - [0.0, 1.707072, 2] + - [0.0, 6.19232, 3] + - [180.0, 0.393296, 4] + C-CX-3C-CT: + - [180.0, 3.397408, 1] + - [180.0, 2.418352, 2] + - [0.0, 1.238464, 3] + - [0.0, 0.937216, 4] + C-CX-3C-2C: + - [0.0, 1.355616, 1] + - [180.0, 6.15048, 2] + - [0.0, 0.945584, 3] + - [0.0, 0.96232, 4] + N-CX-3C-CT: + - [0.0, 2.820016, 1] + - [180.0, 1.807488, 2] + - [0.0, 1.238464, 3] + - [180.0, 0.008368, 4] + N-CX-3C-2C: + - [0.0, 2.59408, 1] + - [180.0, 1.204992, 2] + - [0.0, 0.945584, 3] + - [180.0, 0.811696, 4] + N3-CX-3C-CT: + - [0.0, 2.820016, 1] + - [180.0, 1.807488, 2] + - [0.0, 1.238464, 3] + - [180.0, 0.008368, 4] + N3-CX-3C-2C: + - [0.0, 2.59408, 1] + - [180.0, 1.204992, 2] + - [0.0, 0.945584, 3] + - [180.0, 0.811696, 4] + CX-3C-2C-CT: + - [0.0, 3.740496, 1] + - [0.0, 0.443504, 2] + - [0.0, 0.895376, 3] + - [0.0, 1.92464, 4] + CT-3C-2C-CT: + - [0.0, 1.690336, 1] + - [180.0, 0.644336, 2] + - [0.0, 0.895376, 3] + - [0.0, 1.874432, 4] + C-CX-3C-OH: + - [180.0, 5.832496, 1] + - [180.0, 0.995792, 2] + - [0.0, 2.63592, 3] + - [0.0, 1.305408, 4] + N-CX-3C-OH: + - [0.0, 5.640032, 1] + - [0.0, 0.050208, 2] + - [0.0, 2.63592, 3] + - [0.0, 0.79496, 4] + N3-CX-3C-OH: + - [0.0, 5.640032, 1] + - [0.0, 0.050208, 2] + - [0.0, 2.63592, 3] + - [0.0, 0.79496, 4] + CX-3C-OH-HO: + - [180.0, 0.050208, 1] + - [0.0, 2.100368, 2] + - [0.0, 1.974848, 3] + - [0.0, 0.108784, 4] + CT-3C-OH-HO: + - [0.0, 5.380624, 1] + - [180.0, 0.661072, 2] + - [0.0, 1.974848, 3] + - [0.0, 0.401664, 4] + C-CX-2C-3C: + - [0.0, 5.907808, 1] + - [180.0, 5.18816, 2] + - [0.0, 1.204992, 3] + - [0.0, 1.58992, 4] + N-CX-2C-3C: + - [0.0, 0.820064, 1] + - [180.0, 2.167312, 2] + - [0.0, 1.204992, 3] + - [0.0, 0.610864, 4] + N3-CX-2C-3C: + - [0.0, 0.820064, 1] + - [180.0, 2.167312, 2] + - [0.0, 1.204992, 3] + - [0.0, 0.610864, 4] + CX-2C-3C-CT: + - [0.0, 3.171472, 1] + - [180.0, 0.225936, 2] + - [0.0, 1.188256, 3] + - [0.0, 1.497872, 
4] + C-CX-2C-OH: + - [180.0, 5.531248, 1] + - [180.0, 1.824224, 2] + - [0.0, 3.355568, 3] + - [0.0, 1.079472, 4] + N-CX-2C-OH: + - [0.0, 5.573088, 1] + - [180.0, 2.058528, 2] + - [0.0, 3.355568, 3] + - [0.0, 1.33888, 4] + N3-CX-2C-OH: + - [0.0, 5.573088, 1] + - [180.0, 2.058528, 2] + - [0.0, 3.355568, 3] + - [0.0, 1.33888, 4] + CX-2C-OH-HO: + - [0.0, 1.765648, 1] + - [0.0, 3.715392, 2] + - [0.0, 2.234256, 3] + - [0.0, 0.058576, 4] + C-CX-CT-C*: + - [180.0, 0.142256, 1] + - [180.0, 2.953904, 2] + - [0.0, 1.958112, 3] + - [0.0, 0.619232, 4] + N-CX-CT-C*: + - [0.0, 0.661072, 1] + - [180.0, 2.619184, 2] + - [0.0, 1.958112, 3] + - [0.0, 0.259408, 4] + N3-CX-CT-C*: + - [0.0, 0.661072, 1] + - [180.0, 2.619184, 2] + - [0.0, 1.958112, 3] + - [0.0, 0.259408, 4] + CX-CT-C*-CB: + - [0.0, 3.05432, 1] + - [0.0, 3.414144, 2] + - [0.0, 6.853392, 3] + - [180.0, 0.79496, 4] + CX-CT-C*-CW: + - [0.0, 0.0, 1] + C-CX-CT-CA: + - [0.0, 0.46024, 1] + - [180.0, 3.924592, 2] + - [0.0, 1.606656, 3] + - [180.0, 0.100416, 4] + N-CX-CT-CA: + - [180.0, 0.100416, 1] + - [180.0, 2.42672, 2] + - [0.0, 1.606656, 3] + - [180.0, 0.058576, 4] + N3-CX-CT-CA: + - [180.0, 0.100416, 1] + - [180.0, 2.42672, 2] + - [0.0, 1.606656, 3] + - [180.0, 0.058576, 4] + CX-CT-CA-CA: + - [180.0, 0.577392, 2] + - [180.0, 0.401664, 4] + CA-C-OH-HO: + - [180.0, 7.388944, 2] + - [0.0, 0.54392, 4] + C-CX-2C-2C: + - [180.0, 3.522928, 1] + - [180.0, 3.288624, 2] + - [0.0, 1.204992, 3] + - [0.0, 1.21336, 4] + N-CX-2C-2C: + - [180.0, 0.8368, 1] + - [180.0, 1.539712, 2] + - [0.0, 1.204992, 3] + - [0.0, 0.652704, 4] + N3-CX-2C-2C: + - [180.0, 0.8368, 1] + - [180.0, 1.539712, 2] + - [0.0, 1.204992, 3] + - [0.0, 0.652704, 4] + CX-2C-2C-C: + - [180.0, 1.640128, 1] + - [0.0, 0.694544, 2] + - [180.0, 3.447616, 3] + - [0.0, 1.154784, 4] + 2C-2C-C-O: + - [0.0, 0.0, 0] + 2C-2C-C-OH: + - [180.0, 6.895232, 1] + - [180.0, 9.238272, 2] + - [180.0, 0.2092, 3] + - [180.0, 0.552288, 4] + 2C-2C-C-N: + - [180.0, 5.096112, 1] + - [180.0, 7.07096, 2] + - [180.0, 0.71128, 3] + - [0.0, 0.351456, 4] + CX-2C-2C-CO: + - [180.0, 11.439056, 1] + - [180.0, 1.857696, 2] + - [180.0, 5.087744, 3] + - [180.0, 0.468608, 4] + 2C-2C-CO-O2: + - [180.0, 3.26352, 2] + - [0.0, 0.535552, 4] + CX-2C-2C-S: + - [0.0, 3.489456, 1] + - [0.0, 2.05016, 2] + - [0.0, 0.133888, 3] + - [0.0, 0.234304, 4] + 2C-2C-S-CT: + - [180.0, 2.066896, 1] + - [0.0, 3.698656, 2] + - [0.0, 3.464352, 3] + - [0.0, 0.476976, 4] + C-CX-2C-S: + - [0.0, 5.037536, 1] + - [180.0, 3.296992, 2] + - [0.0, 2.702864, 3] + - [0.0, 2.326304, 4] + N-CX-2C-S: + - [0.0, 3.924592, 1] + - [180.0, 0.175728, 2] + - [0.0, 2.702864, 3] + - [0.0, 0.535552, 4] + N3-CX-2C-S: + - [0.0, 3.924592, 1] + - [180.0, 0.175728, 2] + - [0.0, 2.702864, 3] + - [0.0, 0.535552, 4] + CX-2C-S-S: + - [0.0, 0.468608, 1] + - [0.0, 5.573088, 2] + - [0.0, 2.527136, 3] + - [180.0, 1.12968, 4] + 2C-S-S-2C: + - [0.0, 3.51456, 1] + - [0.0, 37.48864, 2] + - [0.0, 5.706976, 3] + - [0.0, 3.171472, 4] + improper_energy: + length_unit: nm + energy_unit: kj/mol + parameter_names: + pattern: + - [phase, force_constant, periodicity] + parameters: + ?-?-C-O: + - [180.0, 87.864, 2] + ?-O2-C-O2: + - [180.0, 87.864, 2] + ?-?-N-H: + - [180.0, 8.368, 2] + ?-?-N2-H: + - [180.0, 8.368, 2] + ?-?-NA-H: + - [180.0, 8.368, 2] + ?-N2-CA-N2: + - [180.0, 87.864, 2] + ?-CT-N-CT: + - [180.0, 8.368, 2] + ?-CT-N-CX: + - [180.0, 8.368, 2] + ?-?-CA-HA: + - [180.0, 9.2048, 2] + ?-?-CW-H4: + - [180.0, 9.2048, 2] + ?-?-CR-H5: + - [180.0, 9.2048, 2] + ?-?-CV-H4: + - [180.0, 9.2048, 2] + ?-?-CQ-H5: + - 
[180.0, 9.2048, 2] + ?-?-CK-H5: + - [180.0, 9.2048, 2] + ?-?-CP-H5: + - [180.0, 9.2048, 2] + ?-?-CM-H4: + - [180.0, 9.2048, 2] + ?-?-CM-HA: + - [180.0, 9.2048, 2] + ?-?-CS-H4: + - [180.0, 9.2048, 2] + ?-?-CS-HA: + - [180.0, 9.2048, 2] + ?-?-CA-H4: + - [180.0, 9.2048, 2] + ?-?-CA-H5: + - [180.0, 9.2048, 2] + CB-CK-N*-CT: + - [180.0, 8.368, 2] + CB-CP-N*-CT: + - [180.0, 8.368, 2] + C-CM-N*-CT: + - [180.0, 8.368, 2] + C-CS-N*-CT: + - [180.0, 8.368, 2] + C-CS-CM-CT: + - [180.0, 9.2048, 2] + CT-O-C-OH: + - [180.0, 87.864, 2] + CT-CV-CC-NA: + - [180.0, 9.2048, 2] + CT-CW-CC-NB: + - [180.0, 9.2048, 2] + CT-CW-CC-NA: + - [180.0, 9.2048, 2] + CB-CT-C*-CW: + - [180.0, 9.2048, 2] + CA-CA-CA-CT: + - [180.0, 9.2048, 2] + C-CM-CM-CT: + - [180.0, 9.2048, 2] + C-CS-CS-CT: + - [180.0, 9.2048, 2] + CM-N2-CA-NC: + - [180.0, 9.2048, 2] + CS-N2-CA-NC: + - [180.0, 9.2048, 2] + CB-N2-CA-NC: + - [180.0, 9.2048, 2] + N2-NA-CA-NC: + - [180.0, 9.2048, 2] + CA-CA-C-OH: + - [180.0, 9.2048, 2] + CA-CA-CA-OH: + - [180.0, 9.2048, 2] + H5-O-C-OH: + - [180.0, 9.2048, 2] + H5-O-C-OS: + - [180.0, 9.2048, 2] + CM-CT-CM-HA: + - [180.0, 9.2048, 2] + CS-CT-CS-HA: + - [180.0, 9.2048, 2] + Br-CA-CA-CA: + - [180.0, 9.2048, 2] + CM-H4-C-O: + - [180.0, 9.2048, 2] + CS-H4-C-O: + - [180.0, 9.2048, 2] + C-CT-N-H: + - [180.0, 9.2048, 2] + C-CX-N-H: + - [180.0, 9.2048, 2] + C-CT-N-O: + - [180.0, 9.2048, 2] + ?-?-C5-H5: + - [180.0, 9.2048, 2] + CB-C5-N*-CT: + - [180.0, 8.368, 2] + ?-?-C4-H4: + - [180.0, 9.2048, 2] + ?-?-C4-HA: + - [180.0, 9.2048, 2] + C-C4-N*-CT: + - [180.0, 8.368, 2] + C-C4-C4-CT: + - [180.0, 9.2048, 2] + C4-N2-CA-NC: + - [180.0, 9.2048, 2] + CA-CA-C-OS: + - [180.0, 9.2048, 2] + CR-CC-NA-P: + - [180.0, 9.2048, 2] + ?-O2-CO-O2: + - [180.0, 87.864, 2] + 2C-O-C-OH: + - [180.0, 87.864, 2] + CA-CA-CA-2C: + - [180.0, 9.2048, 2] + coulomb_energy: + length_unit: nm + energy_unit: kj/mol + vdw_energy: + length_unit: nm + energy_unit: kj/mol + parameter_names: + pattern: [sigma, epsilon] + parameters: + H: [0.1069079, 0.0656888] + HO: [0.0, 0.0] + HS: [0.1069079, 0.0656888] + HC: [0.2649533, 0.0656888] + H1: [0.2471353, 0.0656888] + H2: [0.2293173, 0.0656888] + H3: [0.2114994, 0.0656888] + HP: [0.1959977, 0.0656888] + HA: [0.2599642, 0.06276] + H4: [0.2510553, 0.06276] + H5: [0.2421463, 0.06276] + HZ: [0.2599642, 0.06276] + O: [0.2959922, 0.87864] + O2: [0.2959922, 0.87864] + OH: [0.3066473, 0.8803136] + OS: [0.3000012, 0.71128] + OP: [0.3296325, 0.71128] + C*: [0.3399669, 0.359824] + CA: [0.3399669, 0.359824] + CC: [0.3399669, 0.359824] + CR: [0.3399669, 0.359824] + CW: [0.3399669, 0.359824] + CN: [0.3399669, 0.359824] + CB: [0.3399669, 0.359824] + CV: [0.3399669, 0.359824] + CI: [0.3399669, 0.4577296] + C5: [0.3399669, 0.359824] + C4: [0.3399669, 0.359824] + CT: [0.3399669, 0.4577296] + CX: [0.3399669, 0.4577296] + C: [0.3399669, 0.359824] + N: [0.3249999, 0.71128] + N2: [0.3249999, 0.71128] + N3: [0.3249999, 0.71128] + NA: [0.3249999, 0.71128] + NB: [0.3249999, 0.71128] + S: [0.3563595, 1.046] + SH: [0.3563595, 1.046] + P: [0.3741774, 0.8368] + MG: [0.1412253, 3.7434248] + C0: [0.3052397, 1.9237572] + F: [0.3118146, 0.255224] + Cl: [0.3470941, 1.1087599] + Br: [0.395559, 1.33888] + I: [0.4187224, 1.6736] + 2C: [0.3399669, 0.4577296] + 3C: [0.3399669, 0.4577296] + C8: [0.3399669, 0.4577296] + CO: [0.3399669, 0.359824] + HW: [0.0, 0.0] + # Editorial fix: the OW values appeared transposed relative to the [sigma, epsilon] + # pattern and to tip3p.yaml in this same change set; reordered so that + # sigma = 0.315061 nm and epsilon = 0.636386 kJ/mol. + OW: [0.315061, 0.636386] + nb_pair_energy: + length_unit: nm + energy_unit: kj/mol + parameter_names: + pattern: [r_scale, r6_scale, r12_scale] + parameters: + ?: [0.8333333, 0.5, 0.5] diff --git 
a/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/forcefield.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/forcefield.py new file mode 100644 index 0000000000000000000000000000000000000000..071ca67b40cd4ba61ee319e7e97ff425097ace15 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/forcefield.py @@ -0,0 +1,85 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Force field parameters +""" +import os +from typing import Union, Tuple + +from ..data import read_yaml, update_dict +from ..template import get_template + + +def get_forcefield(forcefield: Union[str, dict, list]) -> Tuple[dict, dict]: + """ + Get force field parameters from a YAML file, a dict, or a list of these. + + Args: + forcefield (str, dict or list): The file name of a force field parameter + YAML file, an already-loaded dict of parameters, or a list of either + to be merged in order. + + Returns: + parameters (dict), Force field parameters. + template (dict), Molecular template.
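+ + Examples: + >>> # Illustrative usage (editorial addition): a bare name resolves to a YAML + >>> # file shipped in this directory, and a list merges several files in order. + >>> params, template = get_forcefield('amber.ff14sb') + >>> params, template = get_forcefield(['amber.ff14sb', 'tip3p'])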
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    if forcefield is None:
+        return None, None
+
+    if isinstance(forcefield, str):
+        if os.path.exists(forcefield):
+            filename = forcefield
+        else:
+            filename = forcefield.lower()
+            if os.path.splitext(forcefield)[-1] != '.yaml':
+                filename += '.yaml'
+
+            directory, _ = os.path.split(os.path.realpath(__file__))
+            filename = os.path.join(directory, filename)
+            if not os.path.exists(filename):
+                raise ValueError('Cannot find force field parameters file: "'+forcefield+'".')
+
+        forcefield: dict = read_yaml(filename)
+    elif isinstance(forcefield, (list, tuple)):
+        parameters = {}
+        template = []
+        for ff in forcefield:
+            params, temp = get_forcefield(ff)
+            template.append(temp)
+            parameters = update_dict(parameters, params)
+        template = get_template(template)
+        # Merged parameter sets are returned directly.
+        return parameters, template
+    elif not isinstance(forcefield, dict):
+        raise TypeError('The type of forcefield must be str, dict, list or tuple but got: '+str(type(forcefield)))
+
+    template = None
+    if 'template' in forcefield.keys():
+        template = get_template(forcefield.pop('template'))
+
+    if 'parameters' in forcefield.keys():
+        parameters = forcefield.get('parameters')
+    else:
+        parameters = forcefield
+
+    return parameters, template
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/spce.yaml b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/spce.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..089053523b7d93ac7ac3c63eba2fa038738f7967
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/spce.yaml
@@ -0,0 +1,28 @@
+template:
+  base: water.spce.yaml
+parameters:
+  bond_energy:
+    length_unit: nm
+    energy_unit: kj/mol
+    parameter_names:
+      pattern: [bond_length, force_constant]
+    parameters:
+      OW-HW: [0.1, 345000]
+  angle_energy:
+    length_unit: nm
+    energy_unit: kj/mol
+    parameter_names:
+      pattern: [bond_angle, force_constant]
+    parameters:
+      HW-OW-HW: [109.47, 383]
+  coulomb_energy:
+    length_unit: nm
+    energy_unit: kj/mol
+  vdw_energy:
+    length_unit: nm
+    energy_unit: kj/mol
+    parameter_names:
+      pattern: [sigma, epsilon]
+    parameters:
+      OW: [0.316557, 0.650629]
+      HW: [0.0, 0.0]
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/tip3p.yaml b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/tip3p.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..04fc6c87fb4be09fb297028c7ebde7abeb57c1ad
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/forcefield/tip3p.yaml
@@ -0,0 +1,28 @@
+template:
+  base: water.tip3p.yaml
+parameters:
+  bond_energy:
+    length_unit: nm
+    energy_unit: kj/mol
+    parameter_names:
+      pattern: [bond_length, force_constant]
+    parameters:
+      OW-HW: [0.09572, 502416]
+  angle_energy:
+    length_unit: nm
+    energy_unit: kj/mol
+    parameter_names:
+      pattern: [bond_angle, force_constant]
+    parameters:
+      HW-OW-HW: [104.52, 628.02]
+  coulomb_energy:
+    length_unit: nm
+    energy_unit: kj/mol
+  vdw_energy:
+    length_unit: nm
+    energy_unit: kj/mol
+    parameter_names:
+      pattern: [sigma, epsilon]
+    parameters:
+      OW: [0.315061, 0.636386]
+      HW: [0.0, 0.0]
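Both water-model files above resolve through `get_forcefield` like any other bundled parameter set. A minimal usage sketch (illustration only, not part of this patch; it assumes the `mindsponge1` package in this tree is importable at this module path):

```python
# Hypothetical sketch: load the TIP3P parameters defined above.
# A bare name is lower-cased, suffixed with '.yaml', and looked up next to
# forcefield.py; the 'template' entry is split off via get_template().
from mindsponge1.data.forcefield.forcefield import get_forcefield

parameters, template = get_forcefield('TIP3P')
print(sorted(parameters.keys()))
# ['angle_energy', 'bond_energy', 'coulomb_energy', 'vdw_energy']
```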
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/hyperparam.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/hyperparam.py
new file mode 100644
index 0000000000000000000000000000000000000000..9339b2a4c1cde45830647e26e503f72e8369e8d4
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/hyperparam.py
@@ -0,0 +1,304 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Functions for reading and writing hyperparameters in checkpoint files
+"""
+
+import numpy as np
+from mindspore import Tensor
+from mindspore.nn import Cell, CellList
+from mindspore.train import load_checkpoint
+from ..function.functions import get_integer
+
+
+def str_to_tensor(string: str) -> Tensor:
+    """
+    encode string to Tensor[int]
+
+    Args:
+        string (str): The input string.
+
+    Returns:
+        Tensor[int].
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if isinstance(string, (list, tuple)):
+        string = ' '.join(string)
+    # np.frombuffer replaces the deprecated np.fromstring.
+    return Tensor(np.frombuffer(string.encode(), dtype=np.int8))
+
+
+def tensor_to_str(tensor: Tensor) -> str:
+    """
+    decode Tensor[int] to string
+
+    Args:
+        tensor (Tensor[int]): The input tensor.
+
+    Returns:
+        string(str).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    tensor = Tensor(tensor).asnumpy()
+    # ndarray.tobytes replaces the deprecated ndarray.tostring.
+    string = tensor.tobytes().decode()
+    string = string.split()
+    if len(string) == 1:
+        string = string[0]
+    return string
+
+
+def get_class_parameters(hyper_param: dict, prefix: str, num_class: int = 1) -> dict:
+    """
+    get the hyperparameters of a Cell class from a dict of hyperparameters.
+
+    Args:
+        hyper_param (dict): A dict of hyperparameters.
+        prefix (str): Only parameters starting with the prefix will be loaded.
+        num_class (int): The number of the class. Default: 1
+
+    Returns:
+        hyperparameters, dict.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def _get_class_parameters(hyper_param: dict, prefix: str) -> dict:
+        new_params = {}
+        idx = len(prefix) + 1
+        for name, param in hyper_param.items():
+            if name.find(prefix) == 0 \
+                    and (name == prefix or name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
+                new_params[name[idx:]] = param
+        if 'name' in new_params.keys():
+            new_params['name'] = get_hyper_string(new_params, 'name')
+            if len(new_params) == 1:
+                new_params = new_params.get('name')
+
+        if new_params:
+            return new_params
+        return None
+
+    if num_class == 1:
+        return _get_class_parameters(hyper_param, prefix)
+
+    param_list = []
+    for i in range(num_class):
+        param_list.append(_get_class_parameters(
+            hyper_param, prefix+'.'+str(i)))
+    return param_list
+
+
+def get_hyper_parameter(hyper_param: dict, prefix: str):
+    """
+    get hyperparameter.
+
+    Args:
+        hyper_param (dict): A dict of hyperparameters.
+        prefix (str): Only parameters starting with the prefix will be loaded.
+
+    Returns:
+        hyper_param[prefix], Tensor.
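+
+    Examples:
+        >>> # Hypothetical values for illustration: entries come back
+        >>> # as Tensor, and a missing key yields None.
+        >>> hyper_param = {'cutoff': Tensor([1.0])}
+        >>> cutoff = get_hyper_parameter(hyper_param, 'cutoff')
+        >>> get_hyper_parameter(hyper_param, 'missing') is None
+        True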
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if prefix in hyper_param.keys():
+        return Tensor(hyper_param[prefix])
+    return None
+
+
+def get_hyper_string(hyper_param: dict, prefix: str):
+    """
+    get string type hyperparameter.
+
+    Args:
+        hyper_param (dict): A dict of hyperparameters.
+        prefix (str): Only parameters starting with the prefix will be loaded.
+
+    Returns:
+        str. String type hyperparameter.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if prefix in hyper_param.keys():
+        string = hyper_param[prefix]
+        if isinstance(string, str):
+            return string
+        return tensor_to_str(string)
+    return None
+
+
+def set_hyper_parameter(hyper_param: dict, prefix: str, param=None):
+    """
+    put param into hyper_param.
+
+    Args:
+        hyper_param (dict): A dict of hyperparameters.
+        prefix (str): Only parameters starting with the prefix will be loaded.
+        param (Union[str, Tensor]): Parameter to be put into the hyperparameter dict. Default: None
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if param is None:
+        if prefix in hyper_param.keys():
+            hyper_param.pop(prefix)
+    else:
+        if isinstance(param, str):
+            hyper_param[prefix] = str_to_tensor(param)
+        else:
+            hyper_param[prefix] = param
+
+
+def set_class_parameters(hyper_param: dict, prefix: str, cell: Cell):
+    """
+    put hyperparameters into Cell class.
+
+    Args:
+        hyper_param (dict): A dict of hyperparameters.
+        prefix (str): Only parameters starting with the prefix will be loaded.
+        cell (Cell): A neural network cell.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def _set_class_parameters(hyper_param: dict, prefix: str, cell: Cell):
+        if isinstance(cell, Cell):
+            if 'hyper_param' in cell.__dict__.keys():
+                for key, param in cell.hyper_param.items():
+                    set_hyper_parameter(hyper_param, prefix+'.'+key, param)
+            else:
+                set_hyper_parameter(hyper_param, prefix +
+                                    '.name', cell.__class__.__name__)
+        elif isinstance(cell, str):
+            set_hyper_parameter(hyper_param, prefix, cell)
+        elif cell is not None:
+            raise TypeError('The type of "cls" must be "Cell", "str" or list of them, but got "' +
+                            str(type(cell))+'".')
+
+    if isinstance(cell, (CellList, list)):
+        for i, c in enumerate(cell):
+            _set_class_parameters(hyper_param, prefix+'.'+str(i), c)
+    else:
+        _set_class_parameters(hyper_param, prefix, cell)
+
+
+def load_hyper_param_into_class(cls_dict: dict, hyper_param: dict, types: dict, prefix: str = ''):
+    """
+    load hyperparameter into Cell class.
+
+    Args:
+        cls_dict (dict): A dict of cls.
+        hyper_param (dict): A dict of hyperparameters.
+        types (dict): A dict of types of values.
+        prefix (str): Only parameters starting with the prefix will be loaded. Default: ''
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if prefix:
+        prefix = prefix + '.'
+    for key, value_type in types.items():
+        if value_type == 'str':
+            cls_dict[key] = get_hyper_string(hyper_param, prefix+key)
+        elif value_type == 'int':
+            cls_dict[key] = get_integer(hyper_param[prefix+key])
+        elif value_type == 'class':
+            num_class = 1
+            num_key = 'num_' + key
+            if num_key in cls_dict.keys():
+                num_class = get_integer(cls_dict[prefix+num_key])
+                cls_dict[key] = num_class
+            cls_dict[key] = get_class_parameters(
+                hyper_param, prefix+key, num_class)
+        else:
+            cls_dict[key] = get_hyper_parameter(hyper_param, prefix+key)
+
+
+def set_class_into_hyper_param(hyper_param: dict, types: dict, cls: Cell, prefix: str = ''):
+    """
+    take hyperparameters from Cell class.
+
+    Args:
+        hyper_param (dict): A dict of hyperparameters.
+        types (dict): A dict of types of values.
+        cls (Cell): A neural network cell.
+        prefix (str): Only parameters starting with the prefix will be saved. Default: ''
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    #pylint: disable=protected-access
+    if prefix:
+        prefix = prefix + '.'
+    for key, value_type in types.items():
+        if value_type == 'Cell':
+            if key in cls._cells.keys():
+                if cls._cells[key] is not None:
+                    set_class_parameters(
+                        hyper_param, prefix+key, cls._cells[key])
+        else:
+            if key in cls.__dict__.keys():
+                set_hyper_parameter(hyper_param, prefix+key, cls.__dict__[key])
+            elif key in cls._tensor_list.keys():
+                set_hyper_parameter(hyper_param, prefix +
+                                    key, cls._tensor_list[key])
+
+
+def load_hyperparam(ckpt_file_name, prefix='hyperparam', dec_key=None, dec_mode="AES-GCM"):
+    """
+    Load hyperparam from checkpoint file (.ckpt).
+
+    Args:
+        ckpt_file_name (str): Checkpoint file name.
+        prefix (Union[str, list[str], tuple[str]]): Only parameters starting with the prefix
+            will be loaded. Default: 'hyperparam'
+        dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None,
+            the decryption is not required. Default: None
+        dec_mode (str): This parameter is valid only when dec_key is not set to None.
+            Specifies the decryption mode, currently supports 'AES-GCM'
+            and 'AES-CBC'. Default: 'AES-GCM'
+
+    Returns:
+        Dict, key is parameter name, value is a Parameter.
+
+    Raises:
+        ValueError: Checkpoint file is incorrect.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> from mindsponge1.data.hyperparam import load_hyperparam
+        >>>
+        >>> ckpt_file_name = "molct.ckpt"
+        >>> hyper_dict = load_hyperparam(ckpt_file_name, prefix="hyper")
+        >>> print(hyper_dict["hyper.dim_feature"])
+        Tensor(shape=[1], dtype=Int8, value= [128])
+    """
+
+    return load_checkpoint(ckpt_file_name, dec_key=dec_key, dec_mode=dec_mode, specify_prefix=prefix)
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/parameters.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/parameters.py
new file mode 100644
index 0000000000000000000000000000000000000000..08257657cc22ae1929bb57ff1d0c5b42b992f430
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/parameters.py
@@ -0,0 +1,791 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Modeling Module.
+""" +import re +from operator import itemgetter +from itertools import product +from pathlib import Path +from typing import NamedTuple +import numpy as np +from numpy import ndarray + +from .data import get_bonded_types, get_dihedral_types, get_improper_types + +this_directory = Path(__file__).parent + +backbone_atoms = np.array(["N", "CA", "C", "O"], np.str_) +include_backbone_atoms = np.array(["OXT"], np.str_) + + +class ForceConstants(NamedTuple): + """ The structured object for return force field parameters. + """ + bond_params: dict = None + angle_params: dict = None + dihedral_params: dict = None + improper_params: dict = None + angles: np.ndarray = None + dihedrals: np.ndarray = None + improper: np.ndarray = None + excludes: np.ndarray = None + vdw_param: dict = None + hbonds: np.ndarray = None + non_hbonds: np.ndarray = None + pair_params: dict = None + + +class ForceFieldParameters: + r""" + Getting parameters for given bonds and atom types. + + Args: + atom_types(str): The atom types defined in forcefields. + parameters(dict): A dictionary stores all force field constants. + atom_names(str): Unique atom names in an amino acid. Default: None + atom_charges(ndarray): The charge of the atoms. Default: None + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, atom_types, parameters, atom_names=None, atom_charges=None): + self.atom_types = atom_types[0] + self.atom_names = atom_names[0] + atom_nums = atom_types.shape[-1] + assert atom_nums > 0 + self.atom_charges = atom_charges + self.atom_nums = atom_nums + + # Load force field parameters. + self.vdw_params = None + if 'vdw_energy' in parameters.keys(): + self.vdw_params = parameters["vdw_energy"] + + self.bond_params = None + if 'bond_energy' in parameters.keys(): + self.bond_params = parameters["bond_energy"] + + self.angle_params = None + if 'angle_energy' in parameters.keys(): + self.angle_params = parameters["angle_energy"] + + self._dihedrals = None + if 'dihedral_energy' in parameters.keys(): + self._dihedrals = parameters["dihedral_energy"] + + self._improper = None + if 'improper_energy' in parameters.keys(): + self._improper = parameters["improper_energy"] + + self.pair_params = None + if 'nb_pair_energy' in parameters.keys(): + self.pair_params = parameters["nb_pair_energy"] + + self._wildcard = np.array(["X"], dtype=np.str_) + + self.htypes = np.array( + ["H", "HC", "H1", "HS", "H5", "H4", "HP", "HA", "HO"], np.str_ + ) + + self.dihedral_params = None + self.improper_params = None + self.excludes = np.empty(atom_nums)[:, None] + self.vdw_param = {} + self.pair_index = None + + def get_bond_params(self, bonds, atom_type): + """ + Get the force field bond parameters. + + Args: + bonds (ndarray): Array of bonds between two atoms. + atom_type (ndarray): Array of the types of atoms. + + Returns: + dict, params. 
+ """ + bond_atoms = np.take(atom_type, bonds, -1) + + k_index = self.bond_params['parameter_names']["pattern"].index('force_constant') + r_index = self.bond_params['parameter_names']["pattern"].index('bond_length') + + bond_params: dict = self.bond_params['parameters'] + params = {} + for k, v in bond_params.items(): + [a, b] = k.split('-') + if a != b: + params[b + '-' + a] = v + bond_params.update(params) + + bond_type = get_bonded_types(bond_atoms) + type_list: list = bond_type.reshape(-1).tolist() + + if len(type_list) == 1: + bond_length = [bond_params[type_list[0]][r_index]] + force_constant = [bond_params[type_list[0]][k_index]] + else: + bond_length = [] + force_constant = [] + for params in itemgetter(*type_list)(bond_params): + bond_length.append(params[r_index]) + force_constant.append(params[k_index]) + + params = {'bond_index': bonds} + params['force_constant'] = np.array(force_constant, np.float32).reshape(bond_type.shape) + params['bond_length'] = np.array(bond_length, np.float32).reshape(bond_type.shape) + + return params + + def get_angle_params(self, angles, atom_type): + """ + Get the force field angle parameters. + + Args: + angles (ndarray): Array of angles. + atom_type (ndarray): Array of the types of atoms. + + Returns: + dict, params. + """ + angle_atoms = np.take(atom_type, angles, -1) + + k_index = self.angle_params['parameter_names']["pattern"].index('force_constant') + t_index = self.angle_params['parameter_names']["pattern"].index('bond_angle') + + angle_params: dict = self.angle_params['parameters'] + params = {} + for k, v in angle_params.items(): + [a, b, c] = k.split('-') + if a != c: + params[c + '-' + b + '-' + a] = v + angle_params.update(params) + + angle_type = get_bonded_types(angle_atoms) + type_list: list = angle_type.reshape(-1).tolist() + + if len(type_list) == 1: + bond_angle = [angle_params[type_list[0]][t_index]] + force_constant = [angle_params[type_list[0]][k_index]] + else: + bond_angle = [] + force_constant = [] + for params in itemgetter(*type_list)(angle_params): + bond_angle.append(params[t_index]) + force_constant.append(params[k_index]) + + params = {'angle_index': angles} + params['force_constant'] = np.array(force_constant, np.float32).reshape(angle_type.shape) + params['bond_angle'] = np.array(bond_angle, np.float32).reshape(angle_type.shape) / 180 * np.pi + + return params + + def get_dihedral_params(self, dihedrals_in, atom_types): + """ + Get the force field dihedral parameters. + + Args: + dihedrals_in (ndarray): Array of input dihedrals. + atom_type (ndarray): Array of the types of atoms. + + Returns: + dict, params. 
+ """ + dihedral_atoms = np.take(atom_types, dihedrals_in, -1) + + k_index = self._dihedrals['parameter_names']["pattern"][0].index('force_constant') + phi_index = self._dihedrals['parameter_names']["pattern"][0].index('phase') + t_index = self._dihedrals['parameter_names']["pattern"][0].index('periodicity') + + dihedral_params: dict = self._dihedrals['parameters'] + + key_types_ndarray = np.array([specific_name.split('-') for specific_name in dihedral_params.keys()], np.str_) + types_sorted_args = np.argsort((key_types_ndarray == '?').sum(axis=-1)) + sorted_key_types = key_types_ndarray[types_sorted_args] + transformed_key_types = ['-'.join(specific_name).replace('?', '.+').replace('*', '\\*') for + specific_name in sorted_key_types] + + dihedral_types, inverse_dihedral_types = get_dihedral_types(dihedral_atoms) + type_list: list = dihedral_types.reshape(-1).tolist() + inverse_type_list: list = inverse_dihedral_types.reshape(-1).tolist() + + for i, _ in enumerate(type_list): + for key_type in transformed_key_types: + if re.match('^'+key_type+'$', type_list[i]) or re.match('^'+key_type+'$', inverse_type_list[i]): + type_list[i] = key_type.replace('.+', '?').replace('\\', '') + break + + force_constant = [] + phase = [] + periodicity = [] + dihedral_index = [] + for i, params in enumerate(itemgetter(*type_list)(dihedral_params)): + for _, lastd_params in enumerate(params): + dihedral_index.append(dihedrals_in[i]) + force_constant.append(lastd_params[k_index]) + phase.append(lastd_params[phi_index]) + periodicity.append(lastd_params[t_index]) + + params = {} + params['force_constant'] = np.array(force_constant, np.float32) + ks0_filter = np.where(params['force_constant'] != 0)[0] + params['force_constant'] = params['force_constant'][ks0_filter] + params['dihedral_index'] = np.array(dihedral_index, np.int32)[ks0_filter] + params['phase'] = np.array(phase, np.float32)[ks0_filter] / 180 * np.pi + params['periodicity'] = np.array(periodicity, np.float32)[ks0_filter] + + return params + + def get_improper_params(self, improper_in, atom_types, third_id): + """ + Pre-processing of getting improper dihedrals. + + Args: + improper_in (ndarray): Array of input improper dihedrals. + atom_types (ndarray): Array of the types of atoms. + third_id (ndarray): Array of the third IDs. + + Returns: + dict, params. 
+ """ + improper_atoms = np.take(atom_types, improper_in, -1) + + k_index = self._improper['parameter_names']["pattern"][0].index('force_constant') + phi_index = self._improper['parameter_names']["pattern"][0].index('phase') + t_index = self._improper['parameter_names']["pattern"][0].index('periodicity') + + improper_params: dict = self._improper['parameters'] + + key_types_ndarray = np.array([specific_name.split('-') for specific_name in improper_params.keys()], np.str_) + types_sorted_args = np.argsort((key_types_ndarray == '?').sum(axis=-1)) + sorted_key_types = key_types_ndarray[types_sorted_args] + transformed_key_types = ['-'.join(specific_name).replace('?', '.+').replace('*', '\\*') for specific_name in + sorted_key_types] + + improper_types, orders = get_improper_types(improper_atoms) + type_list = improper_types[0].reshape(-1) + + not_defined_mask = np.zeros(type_list.shape).astype(np.int32) + for i, _ in enumerate(type_list): + for key_type in transformed_key_types: + for j, itypes in enumerate(improper_types): + if re.match('^'+key_type+'$', itypes[i]): + this_improper = improper_in[i][np.array(list(orders[j]))] + if this_improper[2] != third_id[i]: + continue + improper_in[i] = this_improper + not_defined_mask[i] = 1 + type_list[i] = key_type.replace('.+', '?').replace('\\', '') + break + else: + continue + break + + type_list = type_list[np.where(not_defined_mask > 0)[0]] + + force_constant = [] + phase = [] + periodicity = [] + improper_index = [] + improper = improper_in[np.where(not_defined_mask > 0)[0]] + for i, params in enumerate(itemgetter(*type_list)(improper_params)): + for _, lastd_params in enumerate(params): + improper_index.append(improper[i]) + force_constant.append(lastd_params[k_index]) + phase.append(lastd_params[phi_index]) + periodicity.append(lastd_params[t_index]) + + params = {'improper_index': np.array(improper_index, np.int32)} + params['force_constant'] = np.array(force_constant, np.float32) + params['phase'] = np.array(phase, np.float32) / 180 * np.pi + params['periodicity'] = np.array(periodicity, np.float32) + + return params + + def construct_angles(self, bonds, bonds_for_angle, middle_id): + for idx in middle_id: + this_bonds = bonds[np.where(bonds_for_angle == idx)[0]] + flatten_bonds = this_bonds.flatten() + this_idx = np.delete(flatten_bonds, np.where(flatten_bonds == idx)) + yield this_idx + + def combinations(self, bonds, bonds_for_angle, middle_id): + """ + Get all the combinations of 3 atoms. + + Args: + bonds (ndarray): Array of bonds. + bonds_for_angle (ndarray): Array of bonds for angles. + middle_id (ndarray): Array of middle IDs. + + Returns: + np.ndarray, angles. + """ + this_idx = self.construct_angles(bonds, bonds_for_angle, middle_id) + id_selections = [ + [[0, 1]], + [[0, 1], [1, 2], [0, 2]], + [[0, 1], [1, 2], [2, 3], [0, 2], [0, 3], [1, 3]], + ] + angles = None + counter = 0 + for idx in this_idx: + selections = id_selections[idx.size - 2] + for selection in selections: + if angles is None: + angles = np.insert(idx[selection], 1, middle_id[counter])[ + None, :] + else: + angles = np.append( + angles, + np.insert(idx[selection], 1, middle_id[counter])[ + None, :], + axis=0, + ) + counter += 1 + return angles + + def construct_hash(self, bonds): + """ + Args: + bonds (ndarray): Array of bonds. + + Returns: + dict, hash map. + """ + hash_map = {} + for i, b in enumerate(bonds): + bond = tuple(b) + hash_map[hash(bond)] = i + return hash_map + + def trans_dangles(self, dangles, middle_id): + """ + Construct the dihedrals. 
+
+        Args:
+            dangles (ndarray): Array of dangles.
+            middle_id (ndarray): Array of middle IDs.
+
+        Returns:
+            np.ndarray, dihedrals.
+        """
+        left_id = np.isin(dangles[:, 0], middle_id[0])
+        left_ele = dangles[:, 2][left_id]
+        left_id = np.isin(dangles[:, 2], middle_id[0])
+        left_ele = np.append(left_ele, dangles[:, 0][left_id])
+        right_id = np.isin(dangles[:, 1], middle_id[0])
+        right_ele = np.unique(dangles[right_id])
+        right_ele = right_ele[np.where(
+            np.isin(right_ele, middle_id, invert=True))[0]]
+        sides = product(right_ele, left_ele)
+        sides_array = np.array(list(sides))
+
+        if sides_array.size == 0:
+            return sides_array
+
+        sides = sides_array[np.where(
+            sides_array[:, 0] != sides_array[:, 1])[0]]
+        left = np.append(
+            sides[:, 0].reshape(sides.shape[0], 1),
+            np.broadcast_to(middle_id, (sides.shape[0],) + middle_id.shape),
+            axis=1,
+        )
+        dihedrals = np.append(
+            left, sides[:, 1].reshape(sides.shape[0], 1), axis=1)
+        return dihedrals
+
+    def get_dihedrals(self, angles, dihedral_middle_id):
+        """
+        Get the dihedral indexes.
+
+        Args:
+            angles (ndarray): Array of angles.
+            dihedral_middle_id (ndarray): Array of dihedral middle indexes.
+
+        Returns:
+            np.ndarray, dihedrals.
+        """
+        dihedrals = None
+        for i in range(dihedral_middle_id.shape[0]):
+            dangles = angles[
+                np.where(
+                    (
+                        np.isin(angles, dihedral_middle_id[i]).sum(axis=1)
+                        * np.isin(angles[:, 1], dihedral_middle_id[i])
+                    )
+                    > 1
+                )[0]
+            ]
+            this_sides = self.trans_dangles(dangles, dihedral_middle_id[i])
+            if this_sides.size == 0:
+                continue
+            if dihedrals is None:
+                dihedrals = this_sides
+            else:
+                dihedrals = np.append(dihedrals, this_sides, axis=0)
+        return dihedrals
+
+    def check_improper(self, bonds, core_id):
+        """
+        Check which core atoms can host an improper dihedral.
+
+        Args:
+            bonds (ndarray): Array of bonds.
+            core_id (ndarray): Array of core indexes.
+
+        Returns:
+            np.ndarray, the core ids that pass the check.
+        """
+        checked_core_id = core_id.copy()
+        bonds_hash = [hash(tuple(x)) for x in bonds]
+        for i in range(core_id.shape[0]):
+            ids_for_idihedral = np.where(
+                np.sum(np.isin(bonds, core_id[i]), axis=1) > 0
+            )[0]
+            bonds_for_idihedral = bonds[ids_for_idihedral]
+            uniques = np.unique(bonds_for_idihedral.flatten())
+            uniques = np.delete(uniques, np.where(uniques == core_id[i])[0])
+            uniques_product = np.array(list(product(uniques, uniques)))
+            uniques_hash = np.array([hash(tuple(x)) for x in product(uniques, uniques)])
+            excludes = np.isin(uniques_hash, bonds_hash)
+            exclude_size = np.unique(uniques_product[excludes]).size
+            # Exclude condition: mark centres that cannot form an improper.
+            if uniques.shape[0] - excludes.sum() <= 2 or exclude_size > 3:
+                checked_core_id[i] = -1
+        return checked_core_id[np.where(checked_core_id > -1)[0]]
+
+    def get_improper(self, bonds, core_id):
+        """
+        Get the improper dihedral indexes.
+
+        Args:
+            bonds (ndarray): Array of bonds.
+            core_id (ndarray): Array of core indexes.
+
+        Returns:
+            - improper (np.ndarray).
+            - new_id (np.ndarray).
+        """
+        improper = None
+        new_id = None
+        for i in range(core_id.shape[0]):
+            ids_for_idihedral = np.where(
+                np.sum(np.isin(bonds, core_id[i]), axis=1) > 0
+            )[0]
+            bonds_for_idihedral = bonds[ids_for_idihedral]
+            if bonds_for_idihedral.shape[0] == 3:
+                idihedral = np.unique(bonds_for_idihedral.flatten())[None, :]
+                if improper is None:
+                    improper = idihedral
+                    new_id = core_id[i]
+                else:
+                    improper = np.append(improper, idihedral, axis=0)
+                    new_id = np.append(new_id, core_id[i])
+            else:
+                # Only SP2 is considered.
+                continue
+        return improper, new_id
+
+    def get_excludes(self, bonds, angles, dihedrals, improper):
+        """
+        Get the excluded atom indexes.
+
+        Args:
+            bonds (ndarray): Array of bonds.
+            angles (ndarray): Array of angles.
+            dihedrals (ndarray): Array of dihedrals.
+            improper (ndarray): Array of improper dihedrals.
+
+        Returns:
+            np.ndarray, the indexes of the excluded atoms.
+        """
+        excludes = []
+        for i in range(self.atom_nums):
+            bond_excludes = bonds[np.where(
+                np.isin(bonds, i).sum(axis=1))[0]].flatten()
+            this_excludes = bond_excludes
+
+            if angles is not None:
+                angle_excludes = angles[
+                    np.where(np.isin(angles, i).sum(axis=1))[0]
+                ].flatten()
+                this_excludes = np.append(this_excludes, angle_excludes)
+
+            if dihedrals is not None:
+                dihedral_excludes = dihedrals[
+                    np.where(np.isin(dihedrals, i).sum(axis=1))[0]
+                ].flatten()
+                this_excludes = np.append(this_excludes, dihedral_excludes)
+            if improper is not None:
+                idihedral_excludes = improper[
+                    np.where(np.isin(improper, i).sum(axis=1))[0]
+                ].flatten()
+                this_excludes = np.append(this_excludes, idihedral_excludes)
+
+            this_excludes = np.unique(this_excludes)
+            excludes.append(this_excludes[np.where(
+                this_excludes != i)[0]].tolist())
+        padding_length = 0
+        for i in range(self.atom_nums):
+            padding_length = max(padding_length, len(excludes[i]))
+        self.excludes = np.empty((self.atom_nums, padding_length))
+        for i in range(self.atom_nums):
+            self.excludes[i] = np.pad(
+                np.array(excludes[i]),
+                (0, padding_length - len(excludes[i])),
+                mode="constant",
+                constant_values=self.atom_nums,
+            )
+        return self.excludes
+
+    def get_vdw_params(self, atom_type: ndarray):
+        """
+        Get the vdw parameters. Supported atom types:
+        ['H','HO','HS','HC','H1','H2','H3','HP','HA','H4',
+        'H5','HZ','O','O2','OH','OS','OP','C*','CI','C5',
+        'C4','CT','CX','C','N','N3','S','SH','P','MG',
+        'C0','F','Cl','Br','I','2C','3C','C8','CO']
+
+        Args:
+            atom_type (ndarray): Array of atom types.
+
+        Returns:
+            dict, parameters.
+        """
+
+        sigma_index = self.vdw_params['parameter_names']["pattern"].index('sigma')
+        eps_index = self.vdw_params['parameter_names']["pattern"].index('epsilon')
+
+        vdw_params = self.vdw_params['parameters']
+        type_list: list = atom_type.reshape(-1).tolist()
+        sigma = []
+        epsilon = []
+        for params in itemgetter(*type_list)(vdw_params):
+            sigma.append(params[sigma_index])
+            epsilon.append(params[eps_index])
+
+        if atom_type.ndim == 2 and atom_type.shape[0] > 1:
+            #TODO
+            type_list: list = atom_type[0].tolist()
+
+        type_set = list(set(type_list))
+        count = np.array([type_list.count(i) for i in type_set], np.int32)
+
+        sigma_set = []
+        eps_set = []
+        for params in itemgetter(*type_set)(vdw_params):
+            sigma_set.append(params[sigma_index])
+            eps_set.append(params[eps_index])
+
+        sigma_set = np.array(sigma_set)
+        eps_set = np.array(eps_set)
+        c6_set = 4 * eps_set * np.power(sigma_set, 6)
+        param_count = count.reshape(1, -1) * count.reshape(-1, 1) - np.diag(count)
+        mean_c6 = np.sum(c6_set * param_count) / param_count.sum()
+
+        params = {}
+        params['sigma'] = np.array(sigma, np.float32).reshape(atom_type.shape)
+        params['epsilon'] = np.array(epsilon, np.float32).reshape(atom_type.shape)
+        params['mean_c6'] = mean_c6.astype(np.float32)
+
+        return params
+
+    def get_pairwise_c6(self, e0, e1, r0, r1):
+        """
+        Calculate the B coefficient in vdw potential.
+
+        Args:
+            e0 (ndarray): The epsilon parameter of the first atom.
+            e1 (ndarray): The epsilon parameter of the second atom.
+            r0 (ndarray): The vdw radius of the first atom.
+            r1 (ndarray): The vdw radius of the second atom.
+
+        Returns:
+            np.ndarray, the B coefficient in vdw potential.
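+
+        Examples:
+            >>> # Hypothetical check of the combining rule
+            >>> # B = 2*sqrt(e0*e1)*(r0+r1)**6 with toy values.
+            >>> float(ForceFieldParameters.get_pairwise_c6(None, 1.0, 4.0, 0.5, 0.5))
+            4.0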
+ """ + e01 = np.sqrt(e0 * e1) + r01 = r0 + r1 + return 2 * e01 * r01 ** 6 + + def get_hbonds(self, bonds): + """ + Get hydrogen bonds. + + Args: + atom_type (ndarray): Array of atoms. + + Returns: + - bonds (np.ndarray), bonds with H. + - bonds (np.ndarray), non H bonds. + """ + hatoms = np.where(np.isin(self.atom_types, self.htypes))[0] + bonds_with_h = np.where(np.isin(bonds, hatoms).sum(axis=-1))[0] + non_hbonds = np.where(np.isin(bonds, hatoms).sum(axis=-1) == 0)[0] + return bonds[bonds_with_h], bonds[non_hbonds] + + def get_pair_index(self, dihedrals, angles, bonds): + """ + Get the non-bonded atom pairs index. + + Args: + dihedrals (ndarray): Array of dihedrals. + angles (ndarray): Array of angles. + bonds (ndarray): Array of bonds. + + Returns: + np.ndarray, non-bonded atom pairs index. + """ + pairs = dihedrals[:, [0, -1]] + pairs.sort() + pair_index = np.unique(pairs, axis=0) + pair_hash = [] + for pair in pair_index: + if pair[0] < pair[1]: + pair_hash.append(hash((pair[0], pair[1]))) + else: + pair_hash.append(hash((pair[1], pair[0]))) + pair_hash = np.array(pair_hash) + angle_hash = [] + for angle in angles: + if angle[0] < angle[-1]: + angle_hash.append(hash((angle[0], angle[-1]))) + else: + angle_hash.append(hash((angle[-1], angle[0]))) + angle_hash = np.array(angle_hash) + bond_hash = [] + for bond in bonds: + b = sorted(bond) + bond_hash.append(hash(tuple(b))) + bond_hash = np.array(bond_hash) + include_index = np.where( + np.isin(pair_hash, angle_hash) + np.isin(pair_hash, bond_hash) == 0 + )[0] + return pair_index[include_index] + + def get_pair_params(self, pair_index, epsilon, sigma): + """ + Return all the pair parameters. + + Args: + pair_index (ndarray): Array of pair indexes. + epsilon (ndarray): Array of epsilon. + sigma (ndarray): Array of sigma. + + Returns: + dict, pair parameters. + """ + + r_index = self.pair_params['parameter_names']["pattern"].index('r_scale') + r6_index = self.pair_params['parameter_names']["pattern"].index('r6_scale') + r12_index = self.pair_params['parameter_names']["pattern"].index('r12_scale') + + pair_params = self.pair_params['parameters'] + if len(pair_params) == 1 and '?' 
in pair_params.keys(): + r_scale = pair_params['?'][r_index] + r6_scale = pair_params['?'][r6_index] + r12_scale = pair_params['?'][r12_index] + else: + #TODO + r_scale = 0 + r6_scale = 0 + r12_scale = 0 + + qiqj = np.take_along_axis(self.atom_charges, pair_index, axis=1) + qiqj = np.prod(qiqj, -1) + + epsilon_ij = epsilon[pair_index] + epsilon_ij = np.sqrt(np.prod(epsilon_ij, -1)) + + sigma_ij = sigma[pair_index] + sigma_ij = np.mean(sigma_ij, -1) + + pair_params = {} + pair_params['qiqj'] = qiqj + pair_params['epsilon_ij'] = epsilon_ij + pair_params['sigma_ij'] = sigma_ij + pair_params['r_scale'] = r_scale + pair_params['r6_scale'] = r6_scale + pair_params['r12_scale'] = r12_scale + + return pair_params + + def __call__(self, bonds): + # pylint: disable=unused-argument + bonds = bonds[0] + atoms_types = self.atom_types.copy() + vdw_params = self.get_vdw_params(atoms_types) + atom_types = np.append(atoms_types, self._wildcard) + + bond_params = None + angle_params = None + if bonds is not None: + hbonds, non_hbonds = self.get_hbonds(bonds) + bond_params = self.get_bond_params(bonds, atoms_types) + middle_id = np.where(np.bincount(bonds.flatten()) > 1)[0] + ids_for_angle = np.where( + np.sum(np.isin(bonds, middle_id), axis=1) > 0)[0] + bonds_for_angle = bonds[ids_for_angle] + angles = self.combinations(bonds, bonds_for_angle, middle_id) + + if angles is not None: + angle_params = self.get_angle_params(angles, atoms_types) + dihedral_middle_id = bonds[ + np.where(np.isin(bonds, middle_id).sum(axis=1) == 2)[0] + ] + dihedrals = self.get_dihedrals(angles, dihedral_middle_id) + dihedral_params = None + if dihedrals is not None: + dihedral_params = self.get_dihedral_params(dihedrals, atom_types) + core_id = np.where(np.bincount(bonds.flatten()) > 2)[0] + improper = None + improper_params = None + if self._improper is not None: + checked_core_id = self.check_improper(bonds, core_id) + improper, third_id = self.get_improper(bonds, checked_core_id) + improper_params = self.get_improper_params(improper, atom_types, third_id) + if dihedrals is not None: + self.pair_index = self.get_pair_index(dihedrals, angles, bonds) + pair_params = self.get_pair_params(self.pair_index, vdw_params['epsilon'], + vdw_params['sigma']) + else: + pair_params = None + self.excludes = self.get_excludes(bonds, angles, dihedrals, improper) + + return ForceConstants( + bond_params, + angle_params, + dihedral_params, + improper_params, + angles, + dihedrals, + improper, + self.excludes, + vdw_params, + hbonds, + non_hbonds, + pair_params, + ) + + return ForceConstants(excludes=self.excludes, vdw_param=self.vdw_param) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d013ca6b53a4d28070f4b2ced91a06bc8246aaf8 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Molecular templates +""" + +from .template import get_template, get_template_index, get_molecule + +__all__ = ['get_template', 'get_template_index', 'get_molecule'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/protein0.yaml b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/protein0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..262a727d6edade06810f6f6af38bc3ad2dd50a0b --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/protein0.yaml @@ -0,0 +1,665 @@ +template: + ALA: + atom_name: [N, H, CA, HA, CB, HB1, HB2, HB3, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 8], [4, 5], [4, 6], [4, 7], [8, 9]] + head_atom: 0 + tail_atom: 8 + ARG: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, HD2, HD3, NE, HE, CZ, + NH1, HH11, HH12, NH2, HH21, HH22, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 1.008, 1.008, 14.01, 1.008, 12.01, 14.01, 1.008, 1.008, 14.01, 1.008, + 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 7, 1, 6, 7, 1, 1, 7, 1, + 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 22], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [10, 13], [13, 14], [13, 15], [15, 16], + [15, 19], [16, 17], [16, 18], [19, 20], [19, 21], [22, 23]] + head_atom: 0 + tail_atom: 22 + ASN: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, OD1, ND2, HD21, HD22, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 16.0, 14.01, + 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 8, 7, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 12], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [9, 10], [9, 11], [12, 13]] + head_atom: 0 + tail_atom: 12 + ASP: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, OD1, OD2, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 16.0, 16.0, + 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 8, 8, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 10], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [10, 11]] + head_atom: 0 + tail_atom: 10 + CYS: + atom_name: [N, H, CA, HA, CB, HB2, HB3, SG, HG, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 32.06, 1.008, 12.01, + 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 16, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 9], [4, 5], [4, 6], [4, 7], [7, 8], + [9, 10]] + head_atom: 0 + tail_atom: 9 + GLN: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, OE1, NE2, HE21, HE22, + C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 16.0, 14.01, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 8, 7, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 15], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [12, 
13], [12, 14], [15, 16]] + head_atom: 0 + tail_atom: 15 + GLU: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, OE1, OE2, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 16.0, 16.0, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 8, 8, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 13], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [13, 14]] + head_atom: 0 + tail_atom: 13 + GLY: + atom_name: [N, H, CA, HA2, HA3, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5], [5, 6]] + head_atom: 0 + tail_atom: 5 + HID: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, ND1, HD1, CE1, HE1, NE2, CD2, HD2, + C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 14.01, 1.008, + 12.01, 1.008, 14.01, 12.01, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 7, 1, 6, 1, 7, 6, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 15], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 13], [8, 9], [8, 10], [10, 11], [10, 12], [12, 13], [13, 14], [15, 16]] + head_atom: 0 + tail_atom: 15 + HIS: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, ND1, CE1, HE1, NE2, HE2, CD2, HD2, + C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 14.01, 12.01, + 1.008, 14.01, 1.008, 12.01, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 7, 6, 1, 7, 1, 6, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 15], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 13], [8, 9], [9, 10], [9, 11], [11, 12], [11, 13], [13, 14], [15, 16]] + head_atom: 0 + tail_atom: 15 + ILE: + atom_name: [N, H, CA, HA, CB, HB, CG2, HG21, HG22, HG23, CG1, HG12, HG13, CD1, + HD11, HD12, HD13, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 1.008, + 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 6, 1, 1, 1, 6, 1, 1, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 17], [4, 5], [4, 6], [4, 10], [6, + 7], [6, 8], [6, 9], [10, 11], [10, 12], [10, 13], [13, 14], [13, 15], [13, 16], + [17, 18]] + head_atom: 0 + tail_atom: 17 + LEU: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG, CD1, HD11, HD12, HD13, CD2, HD21, + HD22, HD23, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 12.01, + 1.008, 1.008, 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 6, 1, 1, 1, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 17], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 13], [9, 10], [9, 11], [9, 12], [13, 14], [13, 15], [13, 16], [17, + 18]] + head_atom: 0 + tail_atom: 17 + LYS: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, HD2, HD3, CE, HE2, + HE3, NZ, HZ1, HZ2, HZ3, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, 14.01, 1.008, 1.008, 1.008, 12.01, + 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 7, 1, 1, 1, 6, + 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 20], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [10, 13], [13, 14], [13, 15], [13, 16], + [16, 17], [16, 18], [16, 19], [20, 21]] + head_atom: 0 + tail_atom: 20 + MET: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, SD, CE, HE1, HE2, HE3, + C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 
1.008, + 32.06, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 16, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 15], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [11, 12], [11, 13], [11, 14], [15, 16]] + head_atom: 0 + tail_atom: 15 + PHE: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, CD1, HD1, CE1, HE1, CZ, HZ, CE2, + HE2, CD2, HD2, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 12.01, 1.008, + 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 18], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 16], [8, 9], [8, 10], [10, 11], [10, 12], [12, 13], [12, 14], [14, 15], + [14, 16], [16, 17], [18, 19]] + head_atom: 0 + tail_atom: 18 + PRO: + atom_name: [N, CD, HD2, HD3, CG, HG2, HG3, CB, HB2, HB3, CA, HA, C, O] + atom_mass: [14.01, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 1.008, 12.01, 16.0] + atomic_number: [7, 6, 1, 1, 6, 1, 1, 6, 1, 1, 6, 1, 6, 8] + bond: [[0, 1], [0, 10], [1, 2], [1, 3], [1, 4], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [12, 13]] + head_atom: 0 + tail_atom: 12 + SER: + atom_name: [N, H, CA, HA, CB, HB2, HB3, OG, HG, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 16.0, 1.008, 12.01, + 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 8, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 9], [4, 5], [4, 6], [4, 7], [7, 8], + [9, 10]] + head_atom: 0 + tail_atom: 9 + THR: + atom_name: [N, H, CA, HA, CB, HB, CG2, HG21, HG22, HG23, OG1, HG1, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 1.008, + 16.0, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 6, 1, 1, 1, 8, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 12], [4, 5], [4, 6], [4, 10], [6, + 7], [6, 8], [6, 9], [10, 11], [12, 13]] + head_atom: 0 + tail_atom: 12 + TRP: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, CD1, HD1, NE1, HE1, CE2, CZ2, HZ2, + CH2, HH2, CZ3, HZ3, CE3, HE3, CD2, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 12.01, 1.008, + 14.01, 1.008, 12.01, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, + 12.01, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 6, 1, 7, 1, 6, 6, 1, 6, 1, 6, 1, 6, 1, + 6, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 22], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 21], [8, 9], [8, 10], [10, 11], [10, 12], [12, 13], [12, 21], [13, 14], + [13, 15], [15, 16], [15, 17], [17, 18], [17, 19], [19, 20], [19, 21], [22, 23]] + head_atom: 0 + tail_atom: 22 + TYR: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, CD1, HD1, CE1, HE1, CZ, OH, HH, CE2, + HE2, CD2, HD2, C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 12.01, 1.008, + 12.01, 1.008, 12.01, 16.0, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 6, 1, 6, 1, 6, 8, 1, 6, 1, 6, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 19], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 17], [8, 9], [8, 10], [10, 11], [10, 12], [12, 13], [12, 15], [13, 14], + [15, 16], [15, 17], [17, 18], [19, 20]] + head_atom: 0 + tail_atom: 19 + VAL: + atom_name: [N, H, CA, HA, CB, HB, CG1, HG11, HG12, HG13, CG2, HG21, HG22, HG23, + C, O] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 1.008, + 12.01, 1.008, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 6, 
1, 1, 1, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 14], [4, 5], [4, 6], [4, 10], [6, + 7], [6, 8], [6, 9], [10, 11], [10, 12], [10, 13], [14, 15]] + head_atom: 0 + tail_atom: 14 + NALA: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB1, HB2, HB3, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 1.008, + 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 10], [6, 7], [6, 8], + [6, 9], [10, 11]] + head_atom: null + tail_atom: 10 + NARG: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, HD2, HD3, + NE, HE, CZ, NH1, HH11, HH12, NH2, HH21, HH22, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 1.008, 1.008, 12.01, 1.008, 1.008, 14.01, 1.008, 12.01, 14.01, 1.008, 1.008, + 14.01, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 7, 1, 6, 7, 1, 1, + 7, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 24], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 11], [9, 12], [12, 13], [12, 14], [12, 15], [15, 16], [15, + 17], [17, 18], [17, 21], [18, 19], [18, 20], [21, 22], [21, 23], [24, 25]] + head_atom: null + tail_atom: 24 + NASN: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, OD1, ND2, HD21, HD22, C, + O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 16.0, 14.01, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 8, 7, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 14], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 11], [11, 12], [11, 13], [14, 15]] + head_atom: null + tail_atom: 14 + NASP: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, OD1, OD2, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 16.0, 16.0, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 8, 8, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 12], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 11], [12, 13]] + head_atom: null + tail_atom: 12 + NCYS: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, SG, HG, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 32.06, + 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 16, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 11], [6, 7], [6, 8], + [6, 9], [9, 10], [11, 12]] + head_atom: null + tail_atom: 11 + NGLN: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, OE1, NE2, + HE21, HE22, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 1.008, 1.008, 12.01, 16.0, 14.01, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 8, 7, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 17], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 11], [9, 12], [12, 13], [12, 14], [14, 15], [14, 16], [17, + 18]] + head_atom: null + tail_atom: 17 + NGLU: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, OE1, OE2, + C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 1.008, 1.008, 12.01, 16.0, 16.0, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 8, 8, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 15], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 11], [9, 12], [12, 13], [12, 14], [15, 16]] + head_atom: null + tail_atom: 15 + NGLY: + atom_name: [N, 
H1, H2, H3, CA, HA2, HA3, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 7], [7, 8]] + head_atom: null + tail_atom: 7 + NHID: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, ND1, HD1, CE1, HE1, NE2, + CD2, HD2, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 14.01, 1.008, 12.01, 1.008, 14.01, 12.01, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 7, 1, 6, 1, 7, 6, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 17], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 15], [10, 11], [10, 12], [12, 13], [12, 14], [14, 15], + [15, 16], [17, 18]] + head_atom: null + tail_atom: 17 + NHIS: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, ND1, CE1, HE1, NE2, HE2, + CD2, HD2, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 14.01, 12.01, 1.008, 14.01, 1.008, 12.01, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 7, 6, 1, 7, 1, 6, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 17], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 15], [10, 11], [11, 12], [11, 13], [13, 14], [13, 15], + [15, 16], [17, 18]] + head_atom: null + tail_atom: 17 + NILE: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB, CG2, HG21, HG22, HG23, CG1, HG12, + HG13, CD1, HD11, HD12, HD13, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, + 1.008, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 6, 1, 1, 1, 6, 1, 1, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 19], [6, 7], [6, 8], + [6, 12], [8, 9], [8, 10], [8, 11], [12, 13], [12, 14], [12, 15], [15, 16], [15, + 17], [15, 18], [19, 20]] + head_atom: null + tail_atom: 19 + NLEU: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, HG, CD1, HD11, HD12, HD13, + CD2, HD21, HD22, HD23, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 1, 6, 1, 1, 1, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 19], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 11], [9, 15], [11, 12], [11, 13], [11, 14], [15, 16], [15, + 17], [15, 18], [19, 20]] + head_atom: null + tail_atom: 19 + NLYS: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, HD2, HD3, + CE, HE2, HE3, NZ, HZ1, HZ2, HZ3, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 1.008, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, 14.01, 1.008, 1.008, + 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 7, 1, 1, + 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 22], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 11], [9, 12], [12, 13], [12, 14], [12, 15], [15, 16], [15, + 17], [15, 18], [18, 19], [18, 20], [18, 21], [22, 23]] + head_atom: null + tail_atom: 22 + NMET: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, HG2, HG3, SD, CE, HE1, HE2, + HE3, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 1.008, 1.008, 32.06, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 1, 1, 16, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 
4], [4, 5], [4, 6], [4, 17], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 11], [9, 12], [12, 13], [13, 14], [13, 15], [13, 16], [17, + 18]] + head_atom: null + tail_atom: 17 + NPHE: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, CD1, HD1, CE1, HE1, CZ, + HZ, CE2, HE2, CD2, HD2, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, + 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, + 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 20], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 18], [10, 11], [10, 12], [12, 13], [12, 14], [14, 15], + [14, 16], [16, 17], [16, 18], [18, 19], [20, 21]] + head_atom: null + tail_atom: 20 + NPRO: + atom_name: [N, H2, H3, CD, HD2, HD3, CG, HG2, HG3, CB, HB2, HB3, CA, HA, C, O] + atom_mass: [14.01, 1.008, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, 12.01, + 1.008, 1.008, 12.01, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 6, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 12], [3, 4], [3, 5], [3, 6], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 11], [9, 12], [12, 13], [12, 14], [14, 15]] + head_atom: null + tail_atom: 14 + NSER: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, OG, HG, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 16.0, + 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 8, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 11], [6, 7], [6, 8], + [6, 9], [9, 10], [11, 12]] + head_atom: null + tail_atom: 11 + NTHR: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB, CG2, HG21, HG22, HG23, OG1, HG1, C, + O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, + 1.008, 1.008, 16.0, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 6, 1, 1, 1, 8, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 14], [6, 7], [6, 8], + [6, 12], [8, 9], [8, 10], [8, 11], [12, 13], [14, 15]] + head_atom: null + tail_atom: 14 + NTRP: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, CD1, HD1, NE1, HE1, CE2, + CZ2, HZ2, CH2, HH2, CZ3, HZ3, CE3, HE3, CD2, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 12.01, 1.008, 14.01, 1.008, 12.01, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, + 12.01, 1.008, 12.01, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 6, 1, 7, 1, 6, 6, 1, 6, 1, 6, 1, + 6, 1, 6, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 24], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 23], [10, 11], [10, 12], [12, 13], [12, 14], [14, 15], + [14, 23], [15, 16], [15, 17], [17, 18], [17, 19], [19, 20], [19, 21], [21, 22], + [21, 23], [24, 25]] + head_atom: null + tail_atom: 24 + NTYR: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB2, HB3, CG, CD1, HD1, CE1, HE1, CZ, + OH, HH, CE2, HE2, CD2, HD2, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, + 12.01, 1.008, 12.01, 1.008, 12.01, 16.0, 1.008, 12.01, 1.008, 12.01, 1.008, + 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 1, 6, 6, 1, 6, 1, 6, 8, 1, 6, 1, 6, 1, + 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 21], [6, 7], [6, 8], + [6, 9], [9, 10], [9, 19], [10, 11], [10, 12], [12, 13], [12, 14], [14, 15], + [14, 17], [15, 16], [17, 18], [17, 19], [19, 20], [21, 22]] + head_atom: null + tail_atom: 21 + NVAL: + atom_name: [N, H1, H2, H3, CA, HA, CB, HB, CG1, HG11, HG12, HG13, 
CG2, HG21, + HG22, HG23, C, O] + atom_mass: [14.01, 1.008, 1.008, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, + 1.008, 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0] + atomic_number: [7, 1, 1, 1, 6, 1, 6, 1, 6, 1, 1, 1, 6, 1, 1, 1, 6, 8] + bond: [[0, 1], [0, 2], [0, 3], [0, 4], [4, 5], [4, 6], [4, 16], [6, 7], [6, 8], + [6, 12], [8, 9], [8, 10], [8, 11], [12, 13], [12, 14], [12, 15], [16, 17]] + head_atom: null + tail_atom: 16 + CALA: + atom_name: [N, H, CA, HA, CB, HB1, HB2, HB3, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0, + 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 8], [4, 5], [4, 6], [4, 7], [8, 9], + [8, 10]] + head_atom: 0 + tail_atom: null + CARG: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, HD2, HD3, NE, HE, CZ, + NH1, HH11, HH12, NH2, HH21, HH22, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 1.008, 1.008, 14.01, 1.008, 12.01, 14.01, 1.008, 1.008, 14.01, 1.008, + 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 7, 1, 6, 7, 1, 1, 7, 1, + 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 22], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [10, 13], [13, 14], [13, 15], [15, 16], + [15, 19], [16, 17], [16, 18], [19, 20], [19, 21], [22, 23], [22, 24]] + head_atom: 0 + tail_atom: null + CASN: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, OD1, ND2, HD21, HD22, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 16.0, 14.01, + 1.008, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 8, 7, 1, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 12], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [9, 10], [9, 11], [12, 13], [12, 14]] + head_atom: 0 + tail_atom: null + CASP: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, OD1, OD2, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 16.0, 16.0, + 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 8, 8, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 10], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [10, 11], [10, 12]] + head_atom: 0 + tail_atom: null + CCYS: + atom_name: [N, H, CA, HA, CB, HB2, HB3, SG, HG, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 32.06, 1.008, 12.01, + 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 16, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 9], [4, 5], [4, 6], [4, 7], [7, 8], + [9, 10], [9, 11]] + head_atom: 0 + tail_atom: null + CGLN: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, OE1, NE2, HE21, HE22, + C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 16.0, 14.01, 1.008, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 8, 7, 1, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 15], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [12, 13], [12, 14], [15, 16], [15, 17]] + head_atom: 0 + tail_atom: null + CGLU: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, OE1, OE2, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 16.0, 16.0, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 8, 8, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 13], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [13, 14], [13, 15]] 
+ head_atom: 0 + tail_atom: null + CGLY: + atom_name: [N, H, CA, HA2, HA3, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5], [5, 6], [5, 7]] + head_atom: 0 + tail_atom: null + CHID: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, ND1, HD1, CE1, HE1, NE2, CD2, HD2, + C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 14.01, 1.008, + 12.01, 1.008, 14.01, 12.01, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 7, 1, 6, 1, 7, 6, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 15], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 13], [8, 9], [8, 10], [10, 11], [10, 12], [12, 13], [13, 14], [15, 16], + [15, 17]] + head_atom: 0 + tail_atom: null + CHIS: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, ND1, CE1, HE1, NE2, HE2, CD2, HD2, + C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 14.01, 12.01, + 1.008, 14.01, 1.008, 12.01, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 7, 6, 1, 7, 1, 6, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 15], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 13], [8, 9], [9, 10], [9, 11], [11, 12], [11, 13], [13, 14], [15, 16], [15, + 17]] + head_atom: 0 + tail_atom: null + CILE: + atom_name: [N, H, CA, HA, CB, HB, CG2, HG21, HG22, HG23, CG1, HG12, HG13, CD1, + HD11, HD12, HD13, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 1.008, + 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 6, 1, 1, 1, 6, 1, 1, 6, 1, 1, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 17], [4, 5], [4, 6], [4, 10], [6, + 7], [6, 8], [6, 9], [10, 11], [10, 12], [10, 13], [13, 14], [13, 15], [13, 16], + [17, 18], [17, 19]] + head_atom: 0 + tail_atom: null + CLEU: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG, CD1, HD11, HD12, HD13, CD2, HD21, + HD22, HD23, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 12.01, + 1.008, 1.008, 1.008, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 6, 1, 1, 1, 6, 1, 1, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 17], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 13], [9, 10], [9, 11], [9, 12], [13, 14], [13, 15], [13, 16], [17, + 18], [17, 19]] + head_atom: 0 + tail_atom: null + CLYS: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, CD, HD2, HD3, CE, HE2, + HE3, NZ, HZ1, HZ2, HZ3, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, 14.01, 1.008, 1.008, 1.008, 12.01, + 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 7, 1, 1, 1, 6, + 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 20], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [10, 13], [13, 14], [13, 15], [13, 16], + [16, 17], [16, 18], [16, 19], [20, 21], [20, 22]] + head_atom: 0 + tail_atom: null + CMET: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, HG2, HG3, SD, CE, HE1, HE2, HE3, + C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 32.06, 12.01, 1.008, 1.008, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 1, 1, 16, 6, 1, 1, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 15], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [11, 12], [11, 13], [11, 14], 
[15, 16], [15, 17]] + head_atom: 0 + tail_atom: null + CPHE: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, CD1, HD1, CE1, HE1, CZ, HZ, CE2, + HE2, CD2, HD2, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 12.01, 1.008, + 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 18], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 16], [8, 9], [8, 10], [10, 11], [10, 12], [12, 13], [12, 14], [14, 15], + [14, 16], [16, 17], [18, 19], [18, 20]] + head_atom: 0 + tail_atom: null + CPRO: + atom_name: [N, CD, HD2, HD3, CG, HG2, HG3, CB, HB2, HB3, CA, HA, C, O, OXT] + atom_mass: [14.01, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, 12.01, 1.008, 1.008, + 12.01, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 6, 1, 1, 6, 1, 1, 6, 1, 1, 6, 1, 6, 8, 8] + bond: [[0, 1], [0, 10], [1, 2], [1, 3], [1, 4], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 9], [7, 10], [10, 11], [10, 12], [12, 13], [12, 14]] + head_atom: 0 + tail_atom: null + CSER: + atom_name: [N, H, CA, HA, CB, HB2, HB3, OG, HG, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 16.0, 1.008, 12.01, + 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 8, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 9], [4, 5], [4, 6], [4, 7], [7, 8], + [9, 10], [9, 11]] + head_atom: 0 + tail_atom: null + CTHR: + atom_name: [N, H, CA, HA, CB, HB, CG2, HG21, HG22, HG23, OG1, HG1, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 1.008, + 16.0, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 6, 1, 1, 1, 8, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 12], [4, 5], [4, 6], [4, 10], [6, + 7], [6, 8], [6, 9], [10, 11], [12, 13], [12, 14]] + head_atom: 0 + tail_atom: null + CTRP: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, CD1, HD1, NE1, HE1, CE2, CZ2, HZ2, + CH2, HH2, CZ3, HZ3, CE3, HE3, CD2, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 12.01, 1.008, + 14.01, 1.008, 12.01, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, + 12.01, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 6, 1, 7, 1, 6, 6, 1, 6, 1, 6, 1, 6, 1, + 6, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 22], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 21], [8, 9], [8, 10], [10, 11], [10, 12], [12, 13], [12, 21], [13, 14], + [13, 15], [15, 16], [15, 17], [17, 18], [17, 19], [19, 20], [19, 21], [22, 23], + [22, 24]] + head_atom: 0 + tail_atom: null + CTYR: + atom_name: [N, H, CA, HA, CB, HB2, HB3, CG, CD1, HD1, CE1, HE1, CZ, OH, HH, CE2, + HE2, CD2, HD2, C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 12.01, 12.01, 1.008, + 12.01, 1.008, 12.01, 16.0, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 1, 6, 6, 1, 6, 1, 6, 8, 1, 6, 1, 6, 1, 6, 8, + 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 19], [4, 5], [4, 6], [4, 7], [7, 8], + [7, 17], [8, 9], [8, 10], [10, 11], [10, 12], [12, 13], [12, 15], [13, 14], + [15, 16], [15, 17], [17, 18], [19, 20], [19, 21]] + head_atom: 0 + tail_atom: null + CVAL: + atom_name: [N, H, CA, HA, CB, HB, CG1, HG11, HG12, HG13, CG2, HG21, HG22, HG23, + C, O, OXT] + atom_mass: [14.01, 1.008, 12.01, 1.008, 12.01, 1.008, 12.01, 1.008, 1.008, 1.008, + 12.01, 1.008, 1.008, 1.008, 12.01, 16.0, 16.0] + atomic_number: [7, 1, 6, 1, 6, 1, 6, 1, 1, 1, 6, 1, 1, 1, 6, 8, 8] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 
14], [4, 5], [4, 6], [4, 10], [6, + 7], [6, 8], [6, 9], [10, 11], [10, 12], [10, 13], [14, 15], [14, 16]] + head_atom: 0 + tail_atom: null + ACE: + atom_name: [H1, CH3, H2, H3, C, O] + atom_mass: [1.008, 12.01, 1.008, 1.008, 12.01, 16.0] + atomic_number: [1, 6, 1, 1, 6, 8] + bond: [[0, 1], [1, 2], [1, 3], [1, 4], [4, 5]] + head_atom: null + tail_atom: 4 + NME: + atom_name: [N, H, CH3, HH31, HH32, HH33] + atom_mass: [14.01, 1.008, 12.01, 1.008, 1.008, 1.008] + atomic_number: [7, 1, 6, 1, 1, 1] + bond: [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]] + head_atom: 0 + tail_atom: null diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/template.py b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/template.py new file mode 100644 index 0000000000000000000000000000000000000000..890de3bbfb5ae1d1e78b07569e4cc6b5b80e4cef --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/template.py @@ -0,0 +1,148 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Read template +""" + +import os +from typing import Union, Tuple +import numpy as np +from numpy import ndarray + +from ..data import update_dict, read_yaml + + +def get_template(template: Union[str, dict, list], residue_name: str = None) -> dict: + """ + Get molecular template. + + Args: + template (Union[str, dict, list]): The file name of template. + residue_name (str): Residue name. + + Returns: + template (dict), Molecular template. 
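+
+    Examples:
+        >>> # A minimal usage sketch, assuming the `water_3p.yaml` template that
+        >>> # this PR adds to the same directory as this module:
+        >>> wat = get_template('water_3p.yaml', residue_name='WAT')
+        >>> wat['atom_name']
+        ['O', 'H1', 'H2']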
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    if template is None or not template:
+        return None
+
+    if isinstance(template, str):
+        if os.path.exists(template):
+            filename = template
+        else:
+            directory, _ = os.path.split(os.path.realpath(__file__))
+            filename = os.path.join(directory, template)
+            if not os.path.exists(filename):
+                raise ValueError('Cannot find template file: "'+template+'".')
+        template: dict = read_yaml(filename)
+    elif isinstance(template, (list, tuple)):
+        template_ = {}
+        for temp in template:
+            temp = get_template(temp)
+            template_ = update_dict(template_, temp)
+        template = template_
+    elif not isinstance(template, dict):
+        raise TypeError('The type of template must be str, dict or list but got: ' + str(type(template)))
+
+    if 'template' in template.keys():
+        template = get_template(template.get('template'))
+
+    if 'base' in template.keys():
+        base = get_template(template.pop('base'))
+        template = update_dict(base, template)
+
+    if residue_name is not None:
+        if residue_name in template.keys():
+            template = template.get(residue_name)
+        else:
+            raise ValueError('Cannot find the residue name "' + str(residue_name) +
+                             '" in template.')
+
+    return template
+
+
+def get_template_index(template: dict, names: ndarray, key: str = 'atom_name') -> ndarray:
+    """
+    Get the atom indices of the system according to the atom names.
+
+    Args:
+        template (dict):    Molecular template.
+        names (ndarray):    Atom names.
+        key (str):          Key of the reference name list in the template.
+                            Default: 'atom_name'
+
+    Returns:
+        index (ndarray), atom indices of the system.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    reference: list = template.get(key)
+    # Map unknown names to len(reference) so that they can be reported below,
+    # instead of letting list.index raise a bare ValueError.
+    index = [reference.index(name) if name in reference else len(reference)
+             for name in names.reshape(-1).tolist()]
+    index = np.array(index, np.int32).reshape(names.shape)
+    unknown_name = (index >= len(reference))
+    if unknown_name.any():
+        raise ValueError('Unknown name: ' + str(names[unknown_name]))
+    return index
+
+
+def get_molecule(template: Union[str, dict]) -> Tuple[dict, dict]:
+    """
+    Get the molecule information and the molecular template.
+
+    Args:
+        template (Union[str, dict]): The file name of the template, or a template dict.
+
+    Returns:
+        - molecule (dict), molecule information, e.g. residue sequence and coordinates.
+        - template (dict), molecular template.
+ + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + if isinstance(template, str): + if os.path.exists(template): + filename = template + else: + directory, _ = os.path.split(os.path.realpath(__file__)) + filename = os.path.join(directory, template) + if not os.path.exists(filename): + raise ValueError('Cannot find template file: "'+template+'".') + template: dict = read_yaml(filename) + elif not isinstance(template, dict): + raise TypeError('The type of template must be str or dict but got :' + + str(type(template))) + + if 'molecule' in template.keys(): + molecule: dict = template.get('molecule') + template: dict = get_template(template) + else: + raise ValueError('Cannot find "molecule" in template') + + for res in molecule.get('residue'): + if res not in template.keys(): + raise ValueError('Cannot find residue name "'+str(res)+'" in template!') + + return molecule, template diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/water.spce.yaml b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/water.spce.yaml new file mode 100644 index 0000000000000000000000000000000000000000..acdebf46bd7144fc0adedc1556e960859f168493 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/water.spce.yaml @@ -0,0 +1,13 @@ +template: + base: water_3p.yaml + WAT: + atom_mass: [15.9994, 1.008, 1.008] + atom_charge: [-0.8476, 0.4238, 0.4238] +molecule: + residue: + - WAT + length_unit: nm + coordinate: + - [0.0, 0.0, 0.0] + - [0.081649043, 0.057735897, 0.0] + - [-0.081649043, 0.057735897, 0.0] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/water.tip3p.yaml b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/water.tip3p.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58e30983556b266047762ce8160a4ec4898b9766 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/water.tip3p.yaml @@ -0,0 +1,12 @@ +template: + base: water_3p.yaml + WAT: + atom_charge: [-0.834, 0.417, 0.417] +molecule: + residue: + - WAT + length_unit: nm + coordinate: + - [0.0, 0.0, 0.0] + - [0.079079641, 0.061207927, 0.0] + - [-0.079079641, 0.061207927, 0.0] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/water_3p.yaml b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/water_3p.yaml new file mode 100644 index 0000000000000000000000000000000000000000..632d357f1530a8a94c07fd1a4410da436de25054 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/data/template/water_3p.yaml @@ -0,0 +1,11 @@ +template: + WAT: + atom_name: [O, H1, H2] + atom_type: [OW, HW, HW] + atom_mass: [16.00, 1.008, 1.008] + atomic_number: [8, 1, 1] + bond: + - [0, 1] + - [0, 2] + head_atom: null + tail_atom: null diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/function/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/function/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f430ceef1a4d9909aa4931164b5831af473182c7 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/function/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. 
+# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Functions and Operations""" + +from .functions import * +from .operations import * +from .units import * + +__all__ = [] +__all__.extend(functions.__all__) +__all__.extend(operations.__all__) +__all__.extend(units.__all__) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/function/functions.py b/MindSPONGE/applications/research/Grasp/mindsponge1/function/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..0470170b97305ed8ab1de9c8d9dab28f3acf4c0a --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/function/functions.py @@ -0,0 +1,1049 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/ ) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Common functions +""" + +from typing import Union +import numpy as np +from numpy import ndarray +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import ops +from mindspore import jit as ms_function +from mindspore.ops import functional as F +from mindspore import Tensor, Parameter, context +# from mindspore.ops._grad.grad_base import bprop_getters +from mindspore.ops._grad_experimental.grad_base import bprop_getters +from mindspore.ops._utils.utils import generate_shape_index +from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like + +__all__ = [ + 'PI', + 'inv', + 'keepdim_sum', + 'keepdim_mean', + 'keepdim_prod', + 'keep_norm_last_dim', + 'norm_last_dim', + 'reduce_any', + 'reduce_all', + 'concat_last_dim', + 'concat_penulti', + 'identity', + 'pbc_box_reshape', + 'periodic_image', + 'displace_in_box', + 'vector_in_box', + 'get_vector_without_pbc', + 'get_vector_with_pbc', + 'get_vector', + 'gather_vectors', + 'gather_values', + 'calc_distance_without_pbc', + 'calc_distance_with_pbc', + 'calc_distance', + 'calc_angle_between_vectors', + 'calc_angle_without_pbc', + 'calc_angle_with_pbc', + 'calc_angle', + 'calc_torsion_for_vectors', + 'calc_torsion_without_pbc', + 'calc_torsion_with_pbc', + 'calc_torsion', + 'get_kinetic_energy', + 'get_integer', + 'get_ndarray', + 'get_tensor', + 'get_ms_array', +] + +PI = 3.141592653589793238462643383279502884197169399375105820974944592307 + +inv = ops.Inv() +keepdim_sum = ops.ReduceSum(keep_dims=True) +keepdim_mean = ops.ReduceMean(keep_dims=True) +keepdim_prod = ops.ReduceProd(keep_dims=True) +reduce_any = ops.ReduceAny() +reduce_all = ops.ReduceAll() +concat_last_dim = ops.Concat(-1) +concat_penulti = ops.Concat(-2) +identity = ops.Identity() +dyn_shape_op = ops.TensorShape() +unsorted_segment_sum = ops.UnsortedSegmentSum() + + +@bprop_getters.register(ops.Slice) +def get_bprop_slice(self): + """Bprop for slice""" + # pylint: disable=W0613 + concat = ops.Concat(axis=2) + def bprop(x, begin, size, out, dout): + # pylint: disable=W0613 + dtype = x.dtype + begin = begin[-1] + size = size[-1] + if begin != 0: + left_tensor = ops.zeros(x.shape[:-1] + (begin,), dtype) + dout = concat((left_tensor, dout)) + shape = x.shape[-1] + if shape != begin + size: + right_tensor = ops.zeros(x.shape[:-1] + (shape - begin - size,), dtype) + dout = concat((dout, right_tensor)) + return (dout, zeros_like(begin), zeros_like(size)) + + return bprop + + +def _generate_inverse_index(x_shape, axis): + x_rank = len(x_shape) + index = tuple(range(x_rank)) + if axis < 0: + axis += x_rank + perm = index[1:1 + axis] + (0,) + index[1 + axis:] + return perm + + +def _regenerate_output_shape(x_shp, ind_shp, axis): + rank = len(x_shp) + if axis < 0: + axis += rank + out_shape = x_shp[:axis] + ind_shp + x_shp[axis + 1:] + return out_shape + + +class GatherNet(ms.nn.Cell): + """Redefine bprop for gather to run unsorted_segment_sum on aicpu""" + def construct(self, data, indices, axis): + return ops.gather(data, indices, axis) + + def bprop(x, indices, axis, out, dout): + """bprop for gather""" + # pylint: disable=E0213, W0613 + orig_indices = indices + if ops.rank(dout) == 0: + dout = ops.ExpandDims()(dout, -1) + if ops.rank(indices) == 0: + indices = ops.ExpandDims()(indices, -1) + x_shp = ops.shape(x) + ind_shp = ops.shape(indices) + out_shp = _regenerate_output_shape(x_shp, ind_shp, axis) + dout = ops.reshape(dout, out_shp) + + x_shp = ops.shape(x) + 
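+        # Scatter the incoming gradient back into the shape of `x`: move the
+        # gathered axis to the front, accumulate the gradients of duplicated
+        # indices with unsorted_segment_sum, then restore the axis order.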
out_shp = ops.shape(dout) + ind_shp = ops.shape(indices) + perm_1 = generate_shape_index(out_shp, ind_shp, axis) + values_transpose = ops.transpose(dout, perm_1) + if F.is_sequence_value_unknown(ops.shape(x)): + params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis]) + else: + params_grad = unsorted_segment_sum(values_transpose, indices, ops.shape(x)[axis]) + perm_2 = _generate_inverse_index(x_shp, axis) + params_grad = ops.transpose(params_grad, perm_2) + return params_grad, zeros_like(orig_indices), zeros_like(axis) + + +gather = GatherNet() if context.get_context("device_target") == "Ascend" else ops.Gather() + + +@ms_function +def norm_last_dim(vector: Tensor) -> Tensor: + r"""Compute the norm of vector, delete the last dims + + Args: + vector (Tensor): Tensor of shape (..., D). Data type is float. + + Returns: + Tensor of shape (...,). Data type is float. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Symbols: + D: Dimension of the simulation system. Usually is 3. + + """ + return msnp.norm(vector, axis=-1) + + +@ms_function +def keep_norm_last_dim(vector: Tensor) -> Tensor: + r"""Compute the norm of vector, keep the last dims + + Args: + vector (Tensor): Tensor of shape (..., D). Data type is float. + + Returns: + Tensor of shape (..., 1). Data type is float. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Symbols: + D: Dimension of the simulation system. Usually is 3. + + """ + return msnp.norm(vector, axis=-1, keep_dims=True) + +@ms_function +def pbc_box_reshape(pbc_box: Tensor, ndim: int) -> Tensor: + r""" + Reshape the pbc_box as the same ndim. + + Args: + pbc_box (Tensor): Tensor of shape (B,D). Data type is float. + ndim (int): The rank (ndim) of the pbc_box. + + Returns: + pbc_box (Tensor), Tensor of shape (B,1,..,1,D). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + if ndim <= 2: + return pbc_box + shape = pbc_box.shape[:1] + (1,) * (ndim - 2) + pbc_box.shape[-1:] + return ops.reshape(pbc_box, shape) + + +@ms_function +def periodic_image(position: Tensor, pbc_box: Tensor, shift: float = 0) -> Tensor: + r""" + calculate the periodic image of the PBC box. + + Args: + position (Tensor): Tensor of shape (B, ..., D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + shift (float): Shift of PBC box. Default: 0 + + Returns: + image (Tensor), Tensor of shape (B, ..., D). Data type is int32. + + Symbols: + - B: Batchsize, i.e. number of walkers in simulation. + - D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + pbc_box = pbc_box_reshape(ops.stop_gradient(pbc_box), position.ndim) + image = -ops.floor(position / pbc_box - shift) + return ops.cast(image, ms.int32) + + +@ms_function +def displace_in_box(position: Tensor, pbc_box: Tensor, shift: float = 0) -> Tensor: + r""" + displace the positions of system in a PBC box. + + Args: + position (Tensor): Tensor of shape (B, ..., D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + shift (float): Shift of PBC box. Default: 0 + + Returns: + position_in box (Tensor), Tensor of shape (B, ..., D). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + D: Dimension of the simulation system. Usually is 3. 
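+
+    Examples:
+        >>> # An illustrative sketch: coordinates outside the box are wrapped
+        >>> # back into it (printed values are float32 approximations).
+        >>> from mindspore import Tensor
+        >>> position = Tensor([[[1.5, -0.25, 0.75]]])
+        >>> pbc_box = Tensor([[1., 1., 1.]])
+        >>> print(displace_in_box(position, pbc_box))
+        [[[0.5  0.75 0.75]]]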
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    pbc_box = pbc_box_reshape(ops.stop_gradient(pbc_box), position.ndim)
+    image = -ops.floor(position / pbc_box - shift)
+    return position + pbc_box * image
+
+
+@ms_function
+def vector_in_box(vector: Tensor, pbc_box: Tensor) -> Tensor:
+    r"""
+    Wrap the vector into the range from -0.5 box to 0.5 box
+    under periodic boundary condition (-0.5*box <= difference < 0.5*box).
+
+    Args:
+        vector (Tensor):    Tensor of shape (B, ..., D). Data type is float.
+        pbc_box (Tensor):   Tensor of shape (B, D). Data type is float.
+
+    Returns:
+        diff_in_box (Tensor), Tensor of shape (B, ..., D). Data type is float.
+
+    Symbols:
+        - B: Batchsize, i.e. number of walkers in simulation.
+        - D: Dimension of the simulation system. Usually is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    pbc_box = pbc_box_reshape(pbc_box, vector.ndim)
+    box_nograd = ops.stop_gradient(pbc_box)
+    inv_box = msnp.reciprocal(box_nograd)
+    vector -= box_nograd * ops.floor(vector * inv_box + 0.5)
+    return vector * inv_box * pbc_box
+
+
+@ms_function
+def get_vector_without_pbc(initial: Tensor, terminal: Tensor, _pbc_box=None) -> Tensor:
+    r"""
+    Compute vector from initial point to terminal point without periodic boundary condition.
+
+    Args:
+        initial (Tensor):   Tensor of shape (B, ..., D). Data type is float.
+                            Coordinate of initial point.
+        terminal (Tensor):  Tensor of shape (B, ..., D). Data type is float.
+                            Coordinate of terminal point.
+        _pbc_box (None):    Dummy.
+
+    Returns:
+        vector (Tensor), Tensor of shape (B, ..., D). Data type is float.
+
+    Symbols:
+        B:  Batchsize, i.e. number of walkers in simulation.
+        D:  Dimension of the simulation system. Usually is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    #pylint: disable=invalid-name
+
+    return terminal - initial
+
+
+@ms_function
+def get_vector_with_pbc(initial: Tensor, terminal: Tensor, pbc_box: Tensor) -> Tensor:
+    r"""
+    Compute vector from initial point to terminal point under periodic boundary condition.
+
+    Args:
+        initial (Tensor):   Tensor of shape (B, ..., D). Data type is float.
+                            Coordinate of initial point.
+        terminal (Tensor):  Tensor of shape (B, ..., D). Data type is float.
+                            Coordinate of terminal point.
+        pbc_box (Tensor):   Tensor of shape (B, D). Data type is float.
+
+    Returns:
+        vector (Tensor), Tensor of shape (B, ..., D). Data type is float.
+
+    Symbols:
+        B:  Batchsize, i.e. number of walkers in simulation.
+        D:  Dimension of the simulation system. Usually is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    return vector_in_box(terminal-initial, pbc_box)
+
+
+@ms_function
+def get_vector(initial: Tensor, terminal: Tensor, pbc_box: Tensor = None) -> Tensor:
+    r"""
+    Compute vector from initial point to terminal point.
+
+    Args:
+        initial (Tensor):   Tensor of shape (B, ..., D). Data type is float.
+                            Coordinate of initial point.
+        terminal (Tensor):  Tensor of shape (B, ..., D). Data type is float.
+                            Coordinate of terminal point.
+        pbc_box (Tensor):   Tensor of shape (B, D). Data type is float.
+                            Default: None
+
+    Returns:
+        vector (Tensor), Tensor of shape (B, ..., D). Data type is float.
+
+    Symbols:
+        B:  Batchsize, i.e. number of walkers in simulation.
+        D:  Dimension of the simulation system. Usually is 3.
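+
+    Examples:
+        >>> # An illustrative sketch: with a PBC box, the difference vector is
+        >>> # wrapped back into the box.
+        >>> from mindspore import Tensor
+        >>> initial = Tensor([[0., 0., 0.]])
+        >>> terminal = Tensor([[0.75, 0., 0.]])
+        >>> pbc_box = Tensor([[1., 1., 1.]])
+        >>> print(get_vector(initial, terminal, pbc_box))
+        [[-0.25  0.    0.  ]]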
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    vector = terminal - initial
+    if pbc_box is None:
+        return vector
+    return vector_in_box(vector, pbc_box)
+
+
+@ms_function
+def gather_vectors(tensor: Tensor, index: Tensor) -> Tensor:
+    r"""
+    Gather vectors from the penultimate axis (axis=-2) of the tensor according to index.
+
+    Args:
+        tensor (Tensor):    Tensor of shape (B, A, D).
+        index (Tensor):     Tensor of shape (B, ...,). Data type is int.
+
+    Returns:
+        vector (Tensor), Tensor of shape (B, ..., D).
+
+    Symbols:
+        B:  Batch size.
+        A:  Number of atoms.
+        D:  Dimension of the simulation system. Usually is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    if index.shape[0] == 1:
+        index1 = ops.reshape(index, index.shape[1:])
+        if tensor.shape[0] == 1:
+            tensor1 = ops.reshape(tensor, tensor.shape[1:])
+            res = gather(tensor1, index1, len(tensor1.shape) - 2)
+            res = ops.reshape(res, (1,) + res.shape)
+            return res
+        return gather(tensor, index1, len(tensor.shape) - 2)
+    if tensor.shape[0] == 1:
+        tensor1 = ops.reshape(tensor, tensor.shape[1:])
+        return gather(tensor1, index, len(tensor1.shape) - 2)
+
+    # (B, N, M):
+    shape0 = index.shape
+    # (B, N*M, 1) <- (B, N, M):
+    index = ops.reshape(index, (shape0[0], -1, 1))
+    # (B, N*M, D) <- (B, N, D):
+    neigh_atoms = msnp.take_along_axis(tensor, index, axis=-2)
+    # (B, N, M, D) <- (B, N, M) + (D,):
+    output_shape = shape0 + tensor.shape[-1:]
+
+    # (B, N, M, D):
+    return ops.reshape(neigh_atoms, output_shape)
+
+
+@ms_function
+def gather_values(tensor: Tensor, index: Tensor) -> Tensor:
+    r"""
+    Gather values from the last axis (axis=-1) of the tensor according to index.
+
+    Args:
+        tensor (Tensor):    Tensor of shape (B, X).
+        index (Tensor):     Tensor of shape (B, ...,). Data type is int.
+
+    Returns:
+        value (Tensor), Tensor of shape (B, ...,).
+
+    Symbols:
+        B:  Batch size.
+        X:  Any value.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    if index.shape[0] == 1:
+        index1 = ops.reshape(index, index.shape[1:])
+        if tensor.shape[0] == 1:
+            tensor1 = ops.reshape(tensor, tensor.shape[1:])
+            res = gather(tensor1, index1, len(tensor1.shape) - 1)
+            res = ops.reshape(res, (1,) + res.shape)
+            return res
+        return gather(tensor, index1, len(tensor.shape) - 1)
+    if tensor.shape[0] == 1:
+        tensor1 = ops.reshape(tensor, tensor.shape[1:])
+        return gather(tensor1, index, len(tensor1.shape) - 1)
+
+    # (B, N, M):
+    origin_shape = index.shape
+    # (B, N*M) <- (B, N, M):
+    index = ops.reshape(index, (origin_shape[0], -1))
+
+    # (B, N*M):
+    neigh_values = ops.gather_d(tensor, -1, index)
+
+    # (B, N, M):
+    return ops.reshape(neigh_values, origin_shape)
+
+
+@ms_function
+def calc_distance_without_pbc(position_a: Tensor, position_b: Tensor, _pbc_box=None) -> Tensor:
+    r"""
+    Compute distance between position A and B without periodic boundary condition.
+
+    Args:
+        position_a (Tensor):    Tensor of shape (..., D). Data type is float.
+        position_b (Tensor):    Tensor of shape (..., D). Data type is float.
+        _pbc_box (None):        Dummy.
+
+    Returns:
+        distance (Tensor), Tensor of shape (..., 1). Data type is float.
+
+    Symbols:
+        D:  Dimension of the simulation system. Usually is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindsponge.function import calc_distance_without_pbc
+        >>> from mindspore.common.tensor import Tensor
+        >>> A = Tensor([[1., 2., 3.]])
+        >>> B = Tensor([[1., 1., 1.]])
+        >>> print(calc_distance_without_pbc(A, B))
+        Tensor(shape=[1, 1], dtype=Float32, value=
+        [[ 2.23606801e+00]])
+    """
+    #pylint: disable=invalid-name
+
+    vec = get_vector_without_pbc(position_a, position_b)
+    return msnp.norm(vec, axis=-1, keepdims=True)
+
+
+@ms_function
+def calc_distance_with_pbc(position_a: Tensor, position_b: Tensor, pbc_box: Tensor) -> Tensor:
+    r"""
+    Compute distance between position A and B under periodic boundary condition.
+
+    Args:
+        position_a (Tensor):    Tensor of shape (B, ..., D). Data type is float.
+        position_b (Tensor):    Tensor of shape (B, ..., D). Data type is float.
+        pbc_box (Tensor):       Tensor of shape (B, D). Data type is float.
+
+    Returns:
+        distance (Tensor), Tensor of shape (B, ..., 1). Data type is float.
+
+    Symbols:
+        B:  Batchsize, i.e. number of walkers in simulation.
+        D:  Dimension of the simulation system. Usually is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindsponge.function import calc_distance_with_pbc
+        >>> from mindspore.common.tensor import Tensor
+        >>> A = Tensor([[1., 2., 3.]])
+        >>> B = Tensor([[1., 1., 1.]])
+        >>> pbc_box = Tensor([[0.7, 0.7, 0.7]])
+        >>> print(calc_distance_with_pbc(A, B, pbc_box))
+        Tensor(shape=[1, 1], dtype=Float32, value=
+        [[ 3.16227734e-01]])
+    """
+
+    vec = get_vector_with_pbc(position_a, position_b, pbc_box)
+    return msnp.norm(vec, axis=-1, keepdims=True)
+
+
+@ms_function
+def calc_distance(position_a: Tensor, position_b: Tensor, pbc_box: Tensor = None) -> Tensor:
+    r"""
+    Compute distance between position A and B.
+
+    Args:
+        position_a (Tensor):    Tensor of shape (B, ..., D). Data type is float.
+        position_b (Tensor):    Tensor of shape (B, ..., D). Data type is float.
+        pbc_box (Tensor):       Tensor of shape (B, D). Data type is float.
+                                Default: None
+
+    Returns:
+        distance (Tensor), Tensor of shape (B, ..., 1). Data type is float.
+
+    Symbols:
+        B:  Batchsize, i.e. number of walkers in simulation.
+        D:  Dimension of the simulation system. Usually is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindsponge.function import calc_distance
+        >>> from mindspore.common.tensor import Tensor
+        >>> A = Tensor([[1., 2., 3.]])
+        >>> B = Tensor([[1., 1., 1.]])
+        >>> pbc_box = Tensor([[0.7, 0.7, 0.7]])
+        >>> print(calc_distance(A, B, pbc_box))
+        Tensor(shape=[1, 1], dtype=Float32, value=
+        [[ 3.16227734e-01]])
+        >>> print(calc_distance(A, B))
+        Tensor(shape=[1, 1], dtype=Float32, value=
+        [[ 2.23606801e+00]])
+    """
+
+    vec = get_vector_without_pbc(position_a, position_b)
+    if pbc_box is not None:
+        vec = vector_in_box(vec, pbc_box)
+    return msnp.norm(vec, axis=-1, keepdims=True)
+
+
+@ms_function
+def calc_angle_between_vectors(vector1: Tensor, vector2: Tensor) -> Tensor:
+    r"""
+    Compute angle between two vectors. For vector :math:`\vec {V_1} = (x_1, x_2, x_3, ..., x_n)`
+    and :math:`\vec {V_2} = (y_1, y_2, y_3, ..., y_n)` , the formula is
+
+    .. math::
+
+        \theta = \arccos {\frac{x_1y_1 + x_2y_2 + \cdots + x_ny_n}{\sqrt{x_1^2 + x_2^2 +
+                 \cdots + x_n^2}\sqrt{y_1^2 + y_2^2 + \cdots + y_n^2}}}
+
+    Args:
+        vector1 (Tensor):   Tensor of shape :math:`(..., D)` . Data type is float.
+        vector2 (Tensor):   Tensor of shape :math:`(..., D)` . Data type is float.
+
+    Returns:
+        angle (Tensor), Tensor of shape :math:`(..., 1)`.
Data type is float. + + Symbols: + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> import mindspore + >>> from mindspore import Tensor + >>> a = Tensor([[1., 2., 3.], [1., 2., 3.]]) + >>> b = Tensor([[1., 1., 1.], [2., 2., 2.]]) + >>> print(mindsponge.function.calc_angle_between_vectors(a, b)) + Tensor(shape=[2, 1], dtype=Float64, value= + [[3.87596687e-01], + [3.87596687e-01]]) + """ + + # [..., 1] <- [..., D] + dis1 = msnp.norm(vector1, axis=-1, keepdims=True) + dis2 = msnp.norm(vector2, axis=-1, keepdims=True) + dot12 = keepdim_sum(vector1 * vector2, -1) + # [..., 1] + cos_theta = dot12 / dis1 / dis2 + return ops.acos(cos_theta) + + +@ms_function +def calc_angle_without_pbc(position_a: Tensor, position_b: Tensor, position_c: Tensor) -> Tensor: + r""" + Compute angle :math:`\angle ABC` formed by three positions A, B, C without periodic boundary condition. + + Calculate the coordinates of vectors :math:`\vec{BA}` and :math:`\vec{BC}` according to the coordinates of A, B, C + without periodic boundary condition, then use the vectors to calculate the angle. + + Args: + position_a (Tensor): Tensor of shape :math:`(..., D)` . Data type is float. + position_b (Tensor): Tensor of shape :math:`(..., D)` . Data type is float. + position_c (Tensor): Tensor of shape :math:`(..., D)` . Data type is float. + + Returns: + angle (Tensor), Tensor of shape (..., 1). Data type is float. + + Symbols: + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> import mindspore + >>> from mindspore import Tensor + >>> A = Tensor([[1., 2., 3.]]) + >>> B = Tensor([[1., 1., 1.]]) + >>> C = Tensor([[4., 5., 6.]]) + >>> print(mindsponge.function.calc_angle_without_pbc(A, B, C)) + Tensor(shape=[1, 1], dtype=Float32, value= + [[ 4.83361423e-01]]) + """ + + # (...,D) + vec_ba = get_vector_without_pbc(position_b, position_a) + vec_bc = get_vector_without_pbc(position_b, position_c) + return calc_angle_between_vectors(vec_ba, vec_bc) + + +@ms_function +def calc_angle_with_pbc(position_a: Tensor, position_b: Tensor, position_c: Tensor, pbc_box: Tensor) -> Tensor: + r""" + Compute angle :math:`\angle ABC` formed by three positions A, B, C with periodic boundary condition. + Put in the coordinates of A, B, C and pbc_box, and get the angle :math:`\angle ABC` . + + Calculate the coordinates of vectors :math:`\vec{BA}` and :math:`\vec{BC}` according to the coordinates of A, B, C + with periodic boundary condition, then use the vectors to calculate the angle. + + Args: + position_a (Tensor): Tensor of shape :math:`(B, ..., D)` . Data type is float. + position_b (Tensor): Tensor of shape :math:`(B, ..., D)` . Data type is float. + position_c (Tensor): Tensor of shape :math:`(B, ..., D)` . Data type is float. + pbc_box (Tensor): Tensor of shape :math:`(B, D)` . Data type is float. + + Returns: + angle (Tensor), Tensor of shape :math:`(B, ..., 1)` . Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + D: Dimension of the simulation system. Usually is 3. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> import mindspore + >>> from mindspore import Tensor + >>> A = Tensor([[1., 2., 3.]]) + >>> B = Tensor([[1., 1., 1.]]) + >>> C = Tensor([[4., 5., 6.]]) + >>> pbc_box = Tensor([[0.7, 0.7, 0.7]]) + >>> print(mindsponge.function.calc_angle_with_pbc(A, B, C, pbc_box=pbc_box)) + Tensor(shape=[1, 1], dtype=Float32, value= + [[ 2.40069723e+00]]) + """ + + # (B, ..., D) + vec_ba = get_vector_with_pbc(position_b, position_a, pbc_box) + vec_bc = get_vector_with_pbc(position_b, position_c, pbc_box) + return calc_angle_between_vectors(vec_ba, vec_bc) + + +@ms_function +def calc_angle(position_a, position_b: Tensor, position_c: Tensor, pbc_box: Tensor = None) -> Tensor: + r""" + Compute angle :math:`\angle ABC` formed by three positions A, B, C. + + If pbc_box is provided, calculate the angle according to the coordinates with periodic boundary condition. + If pbc_box is None, calculate the angle according to the coordinates without periodic boundary condition. + + Finally return the angle between vector :math:`\vec{BA}` and vector :math:`\vec{BC}` . + + Args: + position_a (Tensor): Tensor of shape (B, ..., D). Data type is float. + position_b (Tensor): Tensor of shape (B, ..., D). Data type is float. + position_c (Tensor): Tensor of shape (B, ..., D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. Default: None + + Returns: + angle (Tensor), Tensor of shape (B, ..., 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindsponge + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindsponge.function import calc_angle + >>> A = Tensor([[1., 2., 3.]]) + >>> B = Tensor([[1., 1., 1.]]) + >>> C = Tensor([[4., 5., 6.]]) + >>> print(calc_angle(A, B, C, pbc_box=None)) + Tensor(shape=[1, 1], dtype=Float32, value= + [[ 4.83361423e-01]]) + >>> A = Tensor([[1., 2., 3.]]) + >>> B = Tensor([[1., 1., 1.]]) + >>> C = Tensor([[4., 5., 6.]]) + >>> pbc_box = Tensor([[0.7, 0.7, 0.7]]) + >>> print(calc_angle(A, B, C, pbc_box=pbc_box)) + Tensor(shape=[1, 1], dtype=Float32, value= + [[ 2.40069723e+00]]) + """ + + # (B, ..., D) + if pbc_box is None: + vec_ba = get_vector_without_pbc(position_b, position_a) + vec_bc = get_vector_without_pbc(position_b, position_c) + else: + vec_ba = get_vector_with_pbc(position_b, position_a, pbc_box) + vec_bc = get_vector_with_pbc(position_b, position_c, pbc_box) + return calc_angle_between_vectors(vec_ba, vec_bc) + + +@ms_function +def calc_torsion_for_vectors(vector1: Tensor, vector2: Tensor, vector3: Tensor) -> Tensor: + r""" + Compute torsion angle formed by three vectors. + + Args: + vector1 (Tensor): Tensor of shape (..., D). Data type is float. + vector2 (Tensor): Tensor of shape (..., D). Data type is float. + vector3 (Tensor): Tensor of shape (..., D). Data type is float. + + Returns: + torsion (Tensor), Tensor of shape (..., 1). Data type is float. + + Symbols: + D: Dimension of the simulation system. Usually is 3. 
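+
+    Examples:
+        >>> # An illustrative sketch with three orthogonal unit vectors,
+        >>> # which form a -90 degree (-pi/2) torsion.
+        >>> from mindspore import Tensor
+        >>> v1 = Tensor([[1., 0., 0.]])
+        >>> v2 = Tensor([[0., 1., 0.]])
+        >>> v3 = Tensor([[0., 0., 1.]])
+        >>> print(calc_torsion_for_vectors(v1, v2, v3))
+        [[-1.5707964]]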
+ + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + # (B, ..., D) <- (B,...,1) + v2norm = msnp.norm(vector2, axis=-1, keepdims=True) + # (B, ..., D) = (B, ..., D) / (...,1) + norm_vec2 = vector2 / v2norm + + # (B, ..., D) + vec_a = msnp.cross(norm_vec2, vector1) + vec_b = msnp.cross(vector3, norm_vec2) + cross_ab = msnp.cross(vec_a, vec_b) + + # (B,...,1) + sin_phi = keepdim_sum(cross_ab*norm_vec2, -1) + cos_phi = keepdim_sum(vec_a*vec_b, -1) + + return ops.atan2(-sin_phi, cos_phi) + + +@ms_function +def calc_torsion_without_pbc(position_a: Tensor, + position_b: Tensor, + position_c: Tensor, + position_d: Tensor + ) -> Tensor: + r""" + Compute torsion angle formed by four positions A-B-C-D without periodic boundary condition. + + Args: + position_a (Tensor): Tensor of shape (..., D). Data type is float. + position_b (Tensor): Tensor of shape (..., D). Data type is float. + position_c (Tensor): Tensor of shape (..., D). Data type is float. + position_d (Tensor): Tensor of shape (..., D). Data type is float. + + Returns: + torsion (Tensor), Tensor of shape (..., 1). Data type is float. + + Symbols: + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + vec_ba = get_vector_without_pbc(position_b, position_a) + vec_cb = get_vector_without_pbc(position_c, position_b) + vec_dc = get_vector_without_pbc(position_d, position_c) + return calc_torsion_for_vectors(vec_ba, vec_cb, vec_dc) + + +@ms_function +def calc_torsion_with_pbc(position_a: Tensor, + position_b: Tensor, + position_c: Tensor, + position_d: Tensor, + pbc_box: Tensor + ) -> Tensor: + r""" + Compute torsion angle formed by four positions A-B-C-D at periodic boundary condition. + + Args: + position_a (Tensor): Tensor of shape (B, ..., D). Data type is float. + position_b (Tensor): Tensor of shape (B, ..., D). Data type is float. + position_c (Tensor): Tensor of shape (B, ..., D). Data type is float. + position_d (Tensor): Tensor of shape (B, ..., D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + + Returns: + torsion (Tensor), Tensor of shape (B, ..., 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + vec_ba = get_vector_with_pbc(position_b, position_a, pbc_box) + vec_cb = get_vector_with_pbc(position_c, position_b, pbc_box) + vec_dc = get_vector_with_pbc(position_d, position_c, pbc_box) + return calc_torsion_for_vectors(vec_ba, vec_cb, vec_dc) + + +@ms_function +def calc_torsion(position_a: Tensor, + position_b: Tensor, + position_c: Tensor, + position_d: Tensor, + pbc_box: Tensor = None + ) -> Tensor: + r""" + Compute torsion angle formed by four positions A-B-C-D. + + Args: + position_a (Tensor): Tensor of shape (B, ..., D). Data type is float. + position_b (Tensor): Tensor of shape (B, ..., D). Data type is float. + position_c (Tensor): Tensor of shape (B, ..., D). Data type is float. + position_d (Tensor): Tensor of shape (B, ..., D). Data type is float. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + + Returns: + torsion (Tensor), Tensor of shape (B, ..., 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + D: Dimension of the simulation system. Usually is 3. 
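+
+    Examples:
+        >>> # An illustrative sketch: the four points below form a
+        >>> # 135-degree (3*pi/4) dihedral.
+        >>> from mindspore import Tensor
+        >>> A = Tensor([[0., 0., 0.]])
+        >>> B = Tensor([[1., 0., 0.]])
+        >>> C = Tensor([[1., 1., 0.]])
+        >>> D = Tensor([[2., 1., 1.]])
+        >>> print(calc_torsion(A, B, C, D))
+        [[2.3561945]]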
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    if pbc_box is None:
+        vec_ba = get_vector_without_pbc(position_b, position_a)
+        vec_cb = get_vector_without_pbc(position_c, position_b)
+        vec_dc = get_vector_without_pbc(position_d, position_c)
+    else:
+        vec_ba = get_vector_with_pbc(position_b, position_a, pbc_box)
+        vec_cb = get_vector_with_pbc(position_c, position_b, pbc_box)
+        vec_dc = get_vector_with_pbc(position_d, position_c, pbc_box)
+
+    return calc_torsion_for_vectors(vec_ba, vec_cb, vec_dc)
+
+
+@ms_function
+def get_kinetic_energy(mass: Tensor, velocity: Tensor) -> Tensor:
+    r"""
+    Compute the kinetic energy of the simulation system.
+
+    Args:
+        mass (Tensor):      Tensor of shape (B, A). Data type is float.
+                            Mass of the atoms in system.
+        velocity (Tensor):  Tensor of shape (B, A, D). Data type is float.
+                            Velocities of the atoms in system.
+
+    Returns:
+        kinetics (Tensor), Tensor of shape (B). Data type is float.
+
+    Symbols:
+        B:  Batchsize, i.e. number of walkers in simulation.
+        A:  Number of atoms in the simulation system.
+        D:  Dimension of the simulation system. Usually is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    # (B, A) <- (B, A, D)
+    v2 = ops.reduce_sum(velocity*velocity, -1)
+    # (B, A) * (B, A)
+    kinetics = 0.5 * mass * v2
+    # (B) <- (B, A)
+    return ops.reduce_sum(kinetics, -1)
+
+
+def get_integer(value: Union[int, Tensor, Parameter, ndarray]) -> int:
+    r"""
+    Get the integer type of the input value.
+
+    Args:
+        value (Union[int, Tensor, Parameter, ndarray]): Input value.
+
+    Returns:
+        integer, the integer type of the input value.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if value is None:
+        return None
+    if isinstance(value, Tensor):
+        value = value.asnumpy()
+    return int(value)
+
+
+def get_ndarray(value: Union[Tensor, Parameter, ndarray, list, tuple], dtype: type = None) -> ndarray:
+    r"""
+    Get the ndarray type of the input value.
+
+    Args:
+        value (Union[Tensor, Parameter, ndarray, list, tuple]):    Input value.
+        dtype (type):                                               Data type. Default: None
+
+    Returns:
+        array (ndarray).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if value is None:
+        return None
+    if isinstance(value, (Tensor, Parameter)):
+        value = value.asnumpy()
+        if dtype is not None:
+            value = value.astype(dtype)
+    else:
+        value = np.array(value, dtype)
+    return value
+
+
+def get_tensor(value: Union[Tensor, Parameter, ndarray, list, tuple], dtype: type = None) -> Tensor:
+    r"""
+    Get the mindspore.Tensor type of the input value.
+
+    Args:
+        value (Union[Tensor, Parameter, ndarray, list, tuple]):    Input value.
+        dtype (type):                                               Data type. Default: None
+
+    Returns:
+        tensor (Tensor)
+
+    """
+    if value is None:
+        return None
+
+    if isinstance(value, (list, tuple, ndarray)):
+        value = Tensor(value, dtype)
+    else:
+        if isinstance(value, Parameter):
+            value = identity(value)
+        elif not isinstance(value, Tensor):
+            raise TypeError('The type of input value must be Tensor, Parameter, '
+                            'ndarray, list or tuple but got: ' + str(type(value)))
+        if dtype is not None:
+            value = ops.cast(value, dtype)
+
+    return value
+
+
+def get_ms_array(value: Union[Tensor, Parameter, ndarray, list, tuple], dtype: type = None) -> Union[Tensor, Parameter]:
+    r"""
+    Get the mindspore.Tensor or Parameter type of the input value.
+
+    Args:
+        value (Union[Tensor, Parameter, ndarray, list, tuple]):    Input value.
+        dtype (type):                                               Data type. Default: None
+
+    Returns:
+        array (Tensor or Parameter)
+
+    """
+
+    if value is None:
+        return None
+
+    if isinstance(value, (Tensor, Parameter)):
+        if dtype is not None and value.dtype != dtype:
+            value = ops.cast(value, dtype)
+        return value
+
+    if isinstance(value, (list, tuple, np.ndarray)):
+        return Tensor(value, dtype)
+
+    return None
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/function/operations.py b/MindSPONGE/applications/research/Grasp/mindsponge1/function/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..fb65b48dfa42c3f34fbe8045570532819ab3a754 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/function/operations.py @@ -0,0 +1,465 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+#                       Peking University &
+#                       Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Common operations
+"""
+
+import numpy as np
+import mindspore as ms
+from mindspore import numpy as msnp
+from mindspore.ops import functional as F
+from mindspore import ops, nn
+from mindspore import Tensor
+from mindspore.nn import Cell
+
+from . import functions as func
+from .units import Units, global_units
+
+__all__ = [
+    'GetVector',
+    'GetDistance',
+    'VelocityGenerator',
+    'GetDistanceShift',
+    'GetShiftGrad',
+]
+
+
+class GetVector(Cell):
+    r"""
+    The class to get vector with or without PBC box.
+
+    Args:
+        use_pbc (bool): Whether to calculate vector under periodic boundary condition.
+                        If this is "None", it will determine whether to calculate the vector under
+                        periodic boundary condition based on whether the pbc_box is given.
+                        Default: None
+
+    Returns:
+        vector (Tensor), Tensor of shape (B, ..., D). Data type is float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self, use_pbc: bool = None):
+        super().__init__()
+
+        self.get_vector = self.get_vector_default
+
+        self.use_pbc = use_pbc
+        self.set_pbc(use_pbc)
+
+    def get_vector_without_pbc(self, position0, position1, pbc_box=None):
+        """
+        get vector without periodic boundary condition.
+
+        Args:
+            position0 (Tensor): Tensor of coordinate of initial point.
+            position1 (Tensor): Tensor of coordinate of terminal point.
+            pbc_box (Any):      Dummy. Default: None
+        """
+        return func.get_vector_without_pbc(position0, position1, pbc_box)
+
+    def get_vector_with_pbc(self, position0, position1, pbc_box):
+        """
+        get vector with periodic boundary condition.
+
+        Args:
+            position0 (Tensor): Tensor of coordinate of initial point.
+            position1 (Tensor): Tensor of coordinate of terminal point.
+            pbc_box (Tensor):   Tensor of the PBC box.
+        """
+        return func.get_vector_with_pbc(position0, position1, pbc_box)
+
+    def get_vector_default(self, position0, position1, pbc_box=None):
+        """
+        get vector.
+
+        Args:
+            position0 (Tensor): Tensor of coordinate of initial point.
+            position1 (Tensor): Tensor of coordinate of terminal point.
+            pbc_box (Tensor):   Tensor of the PBC box. Default: None
+        """
+        return func.get_vector(position0, position1, pbc_box)
+
+    def set_pbc(self, use_pbc=None):
+        """
+        set whether to use periodic boundary condition.
+
+        Args:
+            use_pbc (bool): Whether to use periodic boundary condition. Default: None
+        """
+        self.use_pbc = use_pbc
+        if use_pbc is None:
+            self.get_vector = self.get_vector_default
+        else:
+            if use_pbc:
+                self.get_vector = self.get_vector_with_pbc
+            else:
+                self.get_vector = self.get_vector_without_pbc
+        return self
+
+    def construct(self, initial: Tensor, terminal: Tensor, pbc_box: Tensor = None):
+        r"""
+        Compute vector from initial point to terminal point.
+
+        Args:
+            initial (Tensor):   Tensor of shape (B, ..., D). Data type is float.
+                                Coordinate of initial point.
+            terminal (Tensor):  Tensor of shape (B, ..., D). Data type is float.
+                                Coordinate of terminal point.
+            pbc_box (Tensor):   Tensor of shape (B, D). Data type is float.
+                                Default: None
+
+        Returns:
+            vector (Tensor), Tensor of shape (B, ..., D). Data type is float.
+
+        Symbols:
+            B:  Batchsize, i.e. number of walkers in simulation.
+            D:  Dimension of the simulation system. Usually is 3.
+        """
+        return self.get_vector(initial, terminal, pbc_box)
+
+
+class GetDistance(Cell):
+    r"""
+    The class to calculate distance with or without PBC box.
+
+    Args:
+        use_pbc (bool): Whether to calculate distance under periodic boundary condition.
+                        If this is "None", it will determine whether to calculate the distance under
+                        periodic boundary condition based on whether the pbc_box is given.
+                        Default: None
+
+    Outputs:
+        distance (Tensor), Tensor of shape (B, ...). Data type is float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self, use_pbc=None):
+        super().__init__()
+
+        self.get_vector = GetVector(use_pbc)
+        self.norm_last_dim = nn.Norm(axis=-1, keep_dims=False)
+
+    def set_pbc(self, use_pbc):
+        """
+        set whether to use periodic boundary condition.
+
+        Args:
+            use_pbc (bool): Whether to use periodic boundary condition.
+        """
+        self.get_vector.set_pbc(use_pbc)
+        return self
+
+    def construct(self, initial: Tensor, terminal: Tensor, pbc_box: Tensor = None):
+        r"""
+        Compute the distance from initial point to terminal point.
+
+        Args:
+            initial (Tensor):   Tensor of shape (B, ..., D). Data type is float.
+                                Coordinate of initial point.
+            terminal (Tensor):  Tensor of shape (B, ..., D). Data type is float.
+                                Coordinate of terminal point.
+            pbc_box (Tensor):   Tensor of shape (B, D). Data type is float.
+                                Default: None
+
+        Returns:
+            distance (Tensor), Tensor of shape (B, ...). Data type is float.
+
+        Symbols:
+            B:  Batchsize, i.e. number of walkers in simulation.
+            D:  Dimension of the simulation system. Usually is 3.
+
+        """
+        vector = self.get_vector(initial, terminal, pbc_box)
+        return self.norm_last_dim(vector)
+
+
+class VelocityGenerator(Cell):
+    r"""
+    A class to generate velocities for atoms in system according to temperature.
+
+    Args:
+        temperature (float):        Temperature. Default: 300
+        remove_translation (bool):  Whether to remove the net translational momentum of the
+                                    generated velocities. Default: True
+        seed (int):                 Random seed for standard normal. Default: 0
+        seed2 (int):                Random seed2 for standard normal.
Default: 0
+        length_unit (str):  Length unit. Default: None
+        energy_unit (str):  Energy unit. Default: None
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    #pylint: disable=invalid-name
+
+    def __init__(self,
+                 temperature: float = 300,
+                 remove_translation: bool = True,
+                 seed: int = 0,
+                 seed2: int = 0,
+                 length_unit: str = None,
+                 energy_unit: str = None,
+                 ):
+
+        super().__init__()
+
+        if length_unit is None and energy_unit is None:
+            self.units = global_units
+        else:
+            self.units = Units(length_unit, energy_unit)
+
+        self.temperature = Tensor(temperature, ms.float32).reshape(-1, 1, 1)
+
+        self.standard_normal = ops.StandardNormal(seed, seed2)
+
+        self.kb = Tensor(self.units.boltzmann, ms.float32)
+        self.kbT = self.kb * self.temperature
+        self.sigma = F.sqrt(self.kbT)
+
+        self.kinetics_unit_scale = self.units.kinetic_ref
+        self.remove_translation = remove_translation
+        self.identity = ops.Identity()
+
+        self.multi_temp = False
+
+    def set_temperature(self, temperature: float):
+        """
+        set temperature.
+
+        Args:
+            temperature (float): Temperature value.
+        """
+        self.temperature = Tensor(temperature, ms.float32).reshape(-1, 1, 1)
+        self.multi_temp = False
+        if self.temperature is not None and self.temperature.size > 1:
+            self.multi_temp = True
+        return self
+
+    def construct(self, shape: tuple, atom_mass: Tensor, mask: Tensor = None):
+        r"""
+        Randomly generate velocities for atoms in system.
+
+        Args:
+            shape (tuple):      Shape of velocity.
+            atom_mass (Tensor): Tensor of shape (B, A). Data type is float.
+                                Atom mass in system.
+            mask (Tensor):      Tensor of shape (B, A). Data type is bool.
+                                Mask for atoms. Default: None
+
+        Returns:
+            velocity (Tensor), Tensor of shape (B, A, D). Data type is float.
+
+        Symbols:
+            B: Batchsize, i.e. number of walkers in simulation.
+            A: Number of atoms.
+            D: Dimension of the simulation system. Usually is 3.
+        """
+        # (B,A,1)
+        atom_mass = F.expand_dims(self.identity(atom_mass), -1)
+        inv_mass = msnp.reciprocal(atom_mass)
+        velocity_scale = self.sigma * \
+            msnp.sqrt(inv_mass / self.kinetics_unit_scale)
+        if mask is not None:
+            velocity_scale = msnp.where(
+                F.expand_dims(mask, -1), velocity_scale, 0)
+
+        velocity = self.standard_normal(shape) * velocity_scale
+
+        if self.remove_translation:
+            # (B,A,D) * (1,A,1)
+            momentum = atom_mass * velocity
+            # (1,1,1) or (B,1,1) <- (1,A,1) or (B,A,1)
+
+            dp = func.keepdim_mean(momentum, -2)
+            if mask is not None:
+                sp = func.keepdim_sum(momentum, -2)
+                n = func.keepdim_sum(F.cast(mask, ms.int32), -2)
+                dp = sp / n
+            # (B,A,D) - (B,1,D) = (B,A,D)
+            momentum -= dp
+            velocity = momentum * inv_mass
+
+        return velocity
+
+
+class GetDistanceShift(Cell):
+    r"""
+    Module for calculating the B matrix of bond constraints, with one entry
+    per constrained bond C.
+
+    Args:
+        bonds (Tensor):     Tensor of shape (C, 2). Data type is int.
+                            Bonds to be constrained.
+        num_atoms (int):    Number of atoms in system.
+        num_walkers (int):  Number of multiple walkers. Default: 1
+        use_pbc (bool):     Whether to use periodic boundary condition. Default: None
+
+    Outputs:
+        shift (Tensor), Tensor of shape (B,A,D). Data type is float.
+
+    Symbols:
+        B: Batchsize, i.e. number of walkers in simulation.
+        A: Number of atoms.
+        C: Number of constrained bonds.
+        D: Dimension of the simulation system. Usually is 3.
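+
+    Note:
+        A minimal usage sketch (the shapes and variable names here are
+        illustrative, not part of the original code):
+
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor
+        >>> bonds = Tensor(np.array([[0, 1], [1, 2]]), ms.int32)
+        >>> shift_fn = GetDistanceShift(bonds, num_atoms=4, num_walkers=1)
+        >>> x_new = Tensor(np.random.random((1, 4, 3)), ms.float32)
+        >>> x_old = Tensor(np.random.random((1, 4, 3)), ms.float32)
+        >>> shift = shift_fn(x_new, x_old)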
+ + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + bonds: Tensor, + num_atoms: int, + num_walkers: int = 1, + use_pbc: bool = None + ): + + super().__init__(auto_prefix=False) + + # (C,2): + self.bonds = bonds + self.norm = nn.Norm(-1) + + # (B,C,A): + shape = (num_walkers, bonds.shape[-2], num_atoms) + + # (1,C,1): + bond0 = self.bonds[..., 0].reshape(1, -1, 1).asnumpy() + # (B,C,A) <- (B,A,1): + mask0 = np.zeros(shape) + np.put_along_axis(mask0, bond0, 1, axis=-1) + # (B,C,A,1): + self.mask0 = F.expand_dims(Tensor(mask0, ms.int32), -1) + + # (1,C,1): + bond1 = self.bonds[..., 1].reshape(1, -1, 1).asnumpy() + # (B,C,A) <- (B,A,1): + mask1 = np.zeros(shape) + np.put_along_axis(mask1, bond1, 1, axis=-1) + # (B,C,A,1): + self.mask1 = F.expand_dims(Tensor(mask1, ms.int32), -1) + + self.get_distance = GetDistance(use_pbc) + + def construct(self, coordinate_new: Tensor, coordinate_old: Tensor, pbc_box: Tensor = None): + """ + Module for calculating B matrix whose dimensions are C. + + Args: + coordinate_new (Tensor): Tensor of shape (B,A,D). Data type is float. + The new coordinates of the system. + coordinate_old (Tensor): Tensor of shape (B,A,D). Data type is float. + The old coordinates of the system. + pbc_box (Tensor): Tensor of shape (B,D). Data type is float. + Tensor of PBC box + + Return: + shift (Tensor), Tensor of shape (B,A,D). Data type is float. + """ + # (B,C,A,D) = (B,C,A,1) * (B,1,A,D): + pos_new_0 = F.reduce_sum(self.mask0 * coordinate_new, -2) + pos_new_1 = F.reduce_sum(self.mask1 * coordinate_new, -2) + # (B,C,A) + dis_new = self.get_distance(pos_new_0, pos_new_1, pbc_box) + + # (B,C,A,D) = (B,C,A,1) * (B,1,A,D): + pos_old_0 = F.reduce_sum(self.mask0 * coordinate_old, -2) + pos_old_1 = F.reduce_sum(self.mask1 * coordinate_old, -2) + dis_old = self.get_distance(pos_old_0, pos_old_1, pbc_box) + + # (B,C,A) + return dis_new - dis_old + + +class GetShiftGrad(Cell): + """ + Module for calculating the differentiation of B matrix whose dimensions are: K*N*D. + + Args: + bonds (Tensor): Tensor of shape (K, N, D). Data type is int. + Bonds need to be constraint. + num_atoms (int): Number of atoms in system. + num_walkers (int): Number of multiple walkers. Default: 1 + dimension (int): Number of dimension. Default: 3 + use_pbc (bool): Whether to use periodic boundary condition. + + Outputs: + shift (Tensor), Tensor of shape (B,A,D). Data type is float. + + Symbol: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms in system. + N: Number of neighbour atoms. + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + num_atoms: int, + bonds: Tensor, + num_walkers: int = 1, + dimension: int = 3, + use_pbc: bool = None + ): + + super().__init__(auto_prefix=False) + + # (B,K,A,D): + shape = (num_walkers, bonds.shape[-2], num_atoms, dimension) + self.broadcast = ops.BroadcastTo(shape) + self.net = GetDistanceShift( + bonds=bonds, + num_atoms=num_atoms, + num_walkers=num_walkers, + use_pbc=use_pbc + ) + + self.grad = ops.GradOperation() + self.zero_shift = ops.Zeros()((num_walkers, num_atoms - 1, num_atoms, dimension), ms.float32) + + def construct(self, coordinate_new: Tensor, coordinate_old: Tensor, pbc_box: Tensor = None): + """ + Module for calculating the differentiation of B matrix whose dimensions are: K*N*D. + + Args: + coordinate_new (Tensor): Tensor of shape (B,A,D). Data type is float. + The new coordinates of the system. 
+ coordinate_old (Tensor): Tensor of shape (B,A,D). Data type is float. + The old coordinates of the system. + pbc_box (Tensor): Tensor of shape (B,D). Data type is float. + Tensor of PBC box. + + Return: + shift (Tensor), Tensor of shape (B,A,D). Data type is float. + """ + # (B,C,A,D): + coordinate_new = self.broadcast(coordinate_new[:, None, :, :]) + coordinate_old = self.broadcast(coordinate_old[:, None, :, :]) + shift_grad = self.grad(self.net)(coordinate_new, coordinate_old, pbc_box) + if msnp.isnan(shift_grad.sum()): + shift_grad = self.zero_shift + return shift_grad diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/function/units.py b/MindSPONGE/applications/research/Grasp/mindsponge1/function/units.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c241135ce5a173707726d8c224ab615885cb82 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/function/units.py @@ -0,0 +1,1054 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""
+Units
+"""
+
+import math
+
+__all__ = [
+    'AVOGADRO_NUMBER',
+    'BOLTZMANN_CONSTANT',
+    'GAS_CONSTANT',
+    'ELEMENTARY_CHARGE',
+    'VACCUM_PERMITTIVITY',
+    'COULOMB_CONSTANT',
+    'STANDARD_ATMOSPHERE',
+    'Length',
+    'Energy',
+    'get_length_ref',
+    'get_length_unit',
+    'get_length_unit_name',
+    'get_energy_ref',
+    'get_energy_unit',
+    'get_energy_unit_name',
+    'length_convert',
+    'energy_convert',
+    'Units',
+    'global_units',
+    'set_global_length_unit',
+    'set_global_energy_unit',
+    'set_global_units',
+]
+
+# physical constants
+AVOGADRO_NUMBER = 6.02214076e23  # N_A
+BOLTZMANN_CONSTANT = 1.380649e-23  # kB unit=J/K
+GAS_CONSTANT = 8.31446261815324  # R unit=J/(mol*K)
+ELEMENTARY_CHARGE = 1.602176634e-19  # e unit=C
+VACCUM_PERMITTIVITY = 8.854187812813e-12  # \epsilon_0
+COULOMB_CONSTANT = 8.9875517923e9  # k = 1/(4*pi*\epsilon_0) unit=N*m^2/C^2
+STANDARD_ATMOSPHERE = 101325  # 1 atm unit=Pa
+
+_LENGTH_UNITS = (
+    'nm',
+    'um',
+    'a',
+    'angstrom',
+    'bohr',
+    'user',
+    'none',
+)
+
+_LENGTH_REF = {
+    'nm': 1.0,
+    'um': 1e3,
+    'a': 0.1,
+    'angstrom': 0.1,
+    'bohr': 0.052917721067,
+    'user': None,
+    'none': None,
+}
+
+_LENGTH_NAME = {
+    'nm': 'nm',
+    'um': 'um',
+    'a': 'Angstrom',
+    'angstrom': 'Angstrom',
+    'bohr': 'Bohr',
+    'user': 'User_Length',
+    'none': "None"
+}
+
+_ENERGY_UNITS = (
+    'kj/mol',
+    'j/mol',
+    'kcal/mol',
+    'cal/mol',
+    'ha',
+    'ev',
+    'mev',
+    'kbt0',
+    'kbt300',
+    'user',
+    'none',
+)
+
+_ENERGY_REF = {
+    'kj/mol': 1.0,
+    'j/mol': 1e-3,
+    'kcal/mol': 4.184,
+    'cal/mol': 4.184e-3,
+    'ha': 2625.5002,
+    'ev': 96.48530749925793,
+    'mev': 0.09648530749925793,
+    'kbt0': 2.271095464,
+    'kbt300': 2.494338785,
+    'user': None,
+    'none': None,
+}
+
+_ENERGY_NAME = {
+    'kj/mol': 'kJ mol-1',
+    'j/mol': 'J mol-1',
+    'kcal/mol': 'kcal mol-1',
+    'cal/mol': 'cal mol-1',
+    'ha': 'Hartree',
+    'ev': 'eV',
+    'mev': 'meV',
+    'kbt0': 'kBT(273.15K)',
+    'kbt300': 'kBT(300K)',
+    'user': 'User_Energy',
+    'none': 'None',
+}
+
+# Boltzmann constant for simulation (kJ/mol)
+_BOLTZMANN_DEFAULT_REF = 8.31446261815324e-3
+# Coulomb constant for simulation (e^2*kJ/mol*nm)
+# N_A*e^2/(4*pi*\epsilon_0)*1e9nm[1m]*1e-3kJ[1J]
+_COULOMB_DEFAULT_REF = 138.93545764498226165718756672623
+# Pressure 1 Bar in kJ mol-1 nm^3
+_BAR_DEFAULT_REF = 16.6053906717384685
+
+
+class Length:
+    """
+    Length.
+
+    Args:
+        value (float):  length value.
+        unit (str):     length value unit. Default: 'nm'
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self, value: float, unit: str = 'nm'):
+        if isinstance(value, Length):
+            self.__value = value.value
+            self.__unit = value.unit
+            self.__ref = value.ref
+            self.__abs_size = value.abs_size
+            self.__unit_name = value.unit_name
+        elif isinstance(value, (float, int)):
+            self.__value = float(value)
+            if isinstance(unit, (str, Units)):
+                self.__unit = get_length_unit(unit)
+                self.__ref = get_length_ref(unit)
+            elif isinstance(unit, (float, int)):
+                self.__unit = 'user'
+                self.__ref = float(unit)
+            else:
+                raise TypeError(
+                    'Unsupported length unit type: ' + str(type(unit)))
+            self.__abs_size = self.__value * self.__ref
+            self.__unit_name = get_length_unit_name(self.__unit)
+        else:
+            raise TypeError(
+                'Unsupported length value type: ' + str(type(value)))
+
+    def change_unit(self, unit):
+        """
+        change unit.
+
+        Args:
+            unit (str): Length unit.
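+
+        Returns:
+            Length, the object itself with its value rescaled to the new unit.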
+ """ + if isinstance(unit, (str, Units)): + self.__unit = get_length_unit(unit) + self.__ref = get_length_ref(unit) + elif isinstance(unit, (float, int)): + self.__unit = 'user' + self.__ref = unit + else: + raise TypeError('Unsupported length unit type: ' + str(type(unit))) + self.__value = length_convert('nm', unit) * self.__abs_size + self.__unit_name = get_length_unit_name(self.__unit) + return self + + @property + def abs_size(self): + """ + absolute size of length. + + Returns: + float, the absolute size of length. + """ + return self.__abs_size + + @property + def value(self): + """ + value of length. + + Returns: + float, the value of length. + """ + return self.__value + + @property + def ref(self): + """ + reference value. + + Returns: + float, a reference value. + """ + return self.__ref + + @property + def unit(self): + """ + length unit. + + Returns: + str, the length unit. + """ + return self.__unit + + @property + def unit_name(self): + """ + name of length unit. + + Returns: + str, the name of length unit. + """ + return self.__unit_name + + def __call__(self, unit=None): + return self.__value * length_convert(self.__unit, unit) + + def __str__(self): + return str(self.__value) + ' ' + self.__unit_name + + def __lt__(self, other): + if isinstance(other, Length): + return self.__abs_size < other.abs_size + return self.__value < other + + def __gt__(self, other): + if isinstance(other, Length): + return self.__abs_size > other.abs_size + return self.__value > other + + def __eq__(self, other): + if isinstance(other, Length): + return self.__abs_size == other.abs_size + return self.__value == other + + def __le__(self, other): + if isinstance(other, Length): + return self.__abs_size <= other.abs_size + return self.__value <= other + + def __ge__(self, other): + if isinstance(other, Length): + return self.__abs_size >= other.abs_size + return self.__value >= other + + +class Energy: + """ + Energy. + + Args: + value (float): energy value. + unit (str): energy value unit. Default: 'kl/mol' + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, value: float, unit: str = 'kj/mol'): + if isinstance(value, Energy): + self.__value = value.value + self.__unit = value.unit + self.__ref = value.ref + self.__abs_size = value.abs_size + self.__unit_name = value.unit_name + elif isinstance(value, (float, int)): + self.__value = float(value) + if isinstance(unit, (str, Units)): + self.__unit = get_energy_unit(unit) + self.__ref = get_energy_ref(unit) + elif isinstance(unit, (float, int)): + self.__unit = 'user' + self.__ref = float(unit) + else: + raise TypeError('Unsupported energy unit type: ' + str(type(unit))) + self.__abs_size = self.__value * self.__ref + self.__unit_name = get_energy_unit_name(self.__unit) + else: + raise TypeError( + 'Unsupported energy value type: ' + str(type(value))) + + def change_unit(self, unit): + """ + change unit. + + Args: + unit (str): Energy unit. 
+ """ + if isinstance(unit, (str, Units)): + self.__unit = get_energy_unit(unit) + self.__ref = get_energy_ref(unit) + elif isinstance(unit, (float, int)): + self.__unit = 'user' + self.__ref = unit + else: + raise TypeError('Unsupported energy unit type: ' + str(type(unit))) + self.__value = energy_convert('kj/mol', unit) * self.__abs_size + self.__unit_name = get_energy_unit_name(self.__unit) + return self + + def __call__(self, unit=None): + return self.__value * energy_convert(self.__unit, unit) + + def __str__(self): + return str(self.__value) + ' ' + self.__unit_name + + def __lt__(self, other): + if isinstance(other, Energy): + return self.__abs_size < other.abs_size + return self.__value < other + + def __gt__(self, other): + if isinstance(other, Energy): + return self.__abs_size > other.abs_size + return self.__value > other + + def __eq__(self, other): + if isinstance(other, Energy): + return self.__abs_size == other.abs_size + return self.__value == other + + def __le__(self, other): + if isinstance(other, Energy): + return self.__abs_size <= other.abs_size + return self.__value <= other + + def __ge__(self, other): + if isinstance(other, Energy): + return self.__abs_size >= other.abs_size + return self.__value >= other + + @property + def abs_size(self): + """ + absolute size of energy. + + Returns: + float, the absolute size of energy. + """ + return self.__abs_size + + @property + def value(self): + """ + value of energy. + + Returns: + float, the value of energy. + """ + return self.__value + + @property + def ref(self): + """ + reference value. + + Returns: + float, the reference value of energy. + """ + return self.__ref + + @property + def unit(self): + """ + energy unit. + + Returns: + str, the unit of energy value. + """ + return self.__unit + + @property + def unit_name(self): + """ + name of energy unit. + + Returns: + str, the name of energy unit. + """ + return self.__unit_name + + +def get_length_ref(unit): + """ + get length reference. + + Args: + unit (Union[str, Units, Length, float, int]): Length unit. + + Returns: + length reference(Union[str, float, int]). + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + if unit is None: + return None + if isinstance(unit, str): + if unit.lower() not in _LENGTH_REF.keys(): + raise KeyError('length unit "' + unit + '" is not recorded!') + return _LENGTH_REF.get(unit.lower()) + if isinstance(unit, Units): + return unit.length_ref + if isinstance(unit, Length): + return unit.ref + if isinstance(unit, (float, int)): + return unit + raise TypeError('Unsupported length reference type: ' + str(type(unit))) + + +def get_length_unit(unit): + """ + get length unit. + + Args: + unit (Union[str, Units, Length, float, int]): Length unit. + + Returns: + length unit(Union[str, float, int]). + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + if unit is None: + return 'none' + if isinstance(unit, str): + if unit.lower() not in _LENGTH_UNITS: + raise KeyError('length unit "' + unit + '" is not recorded!') + return unit.lower() + if isinstance(unit, Units): + return unit.length_unit + if isinstance(unit, Length): + return unit.unit + if isinstance(unit, (float, int)): + return 'user' + raise TypeError('Unsupported length unit type: ' + str(type(unit))) + + +def get_length_unit_name(unit): + """ + get name of length unit. + + Args: + unit (Union[str, Units, Length, float, int]): Length unit. + + Returns: + length unit(str). 
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if unit is None:
+        return 'None'
+    if isinstance(unit, str):
+        if unit.lower() not in _LENGTH_NAME.keys():
+            raise KeyError('length unit "' + unit + '" is not recorded!')
+        return _LENGTH_NAME.get(unit.lower())
+    if isinstance(unit, Units):
+        return unit.length_unit_name
+    if isinstance(unit, Length):
+        return unit.unit_name
+    if isinstance(unit, (float, int)):
+        return 'User_Length'
+    raise TypeError('Unsupported length unit name type: ' + str(type(unit)))
+
+
+def get_energy_ref(unit):
+    """
+    get energy reference.
+
+    Args:
+        unit (Union[str, Units, Energy, float, int]): Energy unit.
+
+    Returns:
+        energy reference (Union[str, float, int]).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if unit is None:
+        return None
+    if isinstance(unit, str):
+        if unit.lower() not in _ENERGY_REF.keys():
+            raise KeyError('energy unit "' + unit + '" is not recorded!')
+        return _ENERGY_REF.get(unit.lower())
+    if isinstance(unit, Units):
+        return unit.energy_ref
+    if isinstance(unit, Energy):
+        return unit.ref
+    if isinstance(unit, (float, int)):
+        return unit
+    raise TypeError('Unsupported energy reference type: ' + str(type(unit)))
+
+
+def get_energy_unit(unit):
+    """
+    get energy unit.
+
+    Args:
+        unit (Union[str, Units, Energy, float, int]): Energy unit.
+
+    Returns:
+        energy unit (str).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if unit is None:
+        return 'none'
+    if isinstance(unit, str):
+        if unit.lower() not in _ENERGY_UNITS:
+            raise KeyError('energy unit "' + unit + '" is not recorded!')
+        return unit.lower()
+    if isinstance(unit, Units):
+        return unit.energy_unit
+    if isinstance(unit, Energy):
+        return unit.unit
+    if isinstance(unit, (float, int)):
+        return 'user'
+    raise TypeError('Unsupported energy unit type: ' + str(type(unit)))
+
+
+def get_energy_unit_name(unit):
+    """
+    get the name of energy unit.
+
+    Args:
+        unit (Union[str, Units, Energy, float, int]): Energy unit.
+
+    Returns:
+        name of energy unit (str).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    if unit is None:
+        return 'None'
+    if isinstance(unit, str):
+        if unit.lower() not in _ENERGY_NAME.keys():
+            raise KeyError('energy unit "' + unit + '" is not recorded!')
+        return _ENERGY_NAME.get(unit.lower())
+    if isinstance(unit, Units):
+        return unit.energy_unit_name
+    if isinstance(unit, Energy):
+        return unit.unit_name
+    if isinstance(unit, (float, int)):
+        return 'User_Energy'
+    raise TypeError('Unsupported energy unit name type: ' + str(type(unit)))
+
+
+def length_convert(unit_in, unit_out):
+    """
+    convert a length between different units.
+
+    Args:
+        unit_in (Union[str, Units, Length, float, int]): input unit of length.
+        unit_out (Union[str, Units, Length, float, int]): output unit of length.
+
+    Returns:
+        float, the factor converting a length from `unit_in` to `unit_out`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    length_in = get_length_ref(unit_in)
+    length_out = get_length_ref(unit_out)
+    if length_in is None or length_out is None:
+        return 1
+    return length_in / length_out
+
+
+def energy_convert(unit_in, unit_out):
+    """
+    convert an energy between different units.
+
+    Args:
+        unit_in (Union[str, Units, Energy, float, int]): input unit of energy.
+        unit_out (Union[str, Units, Energy, float, int]): output unit of energy.
+
+    Returns:
+        float, the factor converting an energy from `unit_in` to `unit_out`.
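+
+    Note:
+        Conversion factors are defined relative to kJ/mol, so the result
+        equals ``ref(unit_in) / ref(unit_out)``. A quick example using the
+        reference table above:
+
+        >>> energy_convert('kcal/mol', 'kj/mol')
+        4.184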
+ + Supported Platforms: + ``Ascend`` ``GPU`` + """ + energy_in = get_energy_ref(unit_in) + energy_out = get_energy_ref(unit_out) + if energy_in is None or energy_out is None: + return 1 + return energy_in / energy_out + + +class Units: + r""" + Unit class to record and convert the length and energy units. + + Args: + length_unit (str): Length unit. Default: None + energy_unit (str): Energy unit. Default: None + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + length_unit: str = None, + energy_unit: str = None, + ): + + self.__length_unit = get_length_unit(length_unit) + self.__length_unit_name = get_length_unit_name(length_unit) + self.__length_ref = get_length_ref(length_unit) + + self.__energy_unit = get_energy_unit(energy_unit) + self.__energy_unit_name = get_energy_unit_name(energy_unit) + self.__energy_ref = get_energy_ref(energy_unit) + + self.__boltzmann = _BOLTZMANN_DEFAULT_REF + if self.__energy_ref is not None: + self.__boltzmann /= self.__energy_ref + self.__coulomb = _COULOMB_DEFAULT_REF + if self.__length_ref is not None and self.__energy_ref is not None: + self.__coulomb /= self.__energy_ref * self.__length_ref + + self.time_unit = 'ps' + + def set_length_unit(self, unit=None): + """ + set length unit. + + Args: + unit (str): Length unit. + """ + if unit is not None: + self.__length_unit = get_length_unit(unit) + self.__length_unit_name = get_length_unit_name(unit) + self.__length_ref = get_length_ref(unit) + self._set_constants() + return self + + def set_energy_unit(self, unit=None): + """ + set energy unit. + + Args: + unit (str): Energy unit. + """ + if unit is not None: + self.__energy_unit = get_energy_unit(unit) + self.__energy_unit_name = get_energy_unit_name(unit) + self.__energy_ref = get_energy_ref(unit) + self._set_constants() + return self + + def set_units(self, length_unit, energy_unit, units=None): + """ + set unit. + + Args: + length_unit (str): Length unit. + energy_unit (str): Energy unit. + units (str): Units. + """ + if units is None: + if length_unit is not None: + self.__length_unit = get_length_unit(length_unit) + self.__length_unit_name = get_length_unit_name(length_unit) + self.__length_ref = get_length_ref(length_unit) + if energy_unit is not None: + self.__energy_unit = get_energy_unit(energy_unit) + self.__energy_unit_name = get_energy_unit_name(energy_unit) + self.__energy_ref = get_energy_ref(energy_unit) + else: + if not isinstance(units, Units): + raise TypeError('The type of units must be "Units"') + self.__length_unit = get_length_unit(units) + self.__length_unit_name = get_length_unit_name(units) + self.__length_ref = get_length_ref(units) + self.__energy_unit = get_energy_unit(units) + self.__energy_unit_name = get_energy_unit_name(units) + self.__energy_ref = get_energy_ref(units) + return self._set_constants() + + def _set_constants(self): + """set constant values""" + self.__boltzmann = _BOLTZMANN_DEFAULT_REF + if self.__energy_ref is not None: + self.__boltzmann /= self.__energy_ref + self.__coulomb = _COULOMB_DEFAULT_REF + if self.__length_ref is not None and self.__energy_ref is not None: + self.__coulomb /= self.__energy_ref * self.__length_ref + return self + + def length(self, value, unit=None): + """ + return the length value of the specified unit. + + Args: + value (float): Length value. + unit (str): Length unit. + + Returns: + float, the length value. + """ + return value * self.convert_length_from(unit) + + def energy(self, value, unit=None): + """ + return the energy value of the specified unit. 
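+        The input is simply multiplied by the conversion factor from `unit`
+        to this object's energy unit.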
+
+        Args:
+            value (float):  Energy value.
+            unit (str):     Energy unit.
+
+        Returns:
+            float, the energy value.
+        """
+        return value * self.convert_energy_from(unit)
+
+    def convert_length_to(self, unit):
+        """
+        convert length to the specified unit.
+
+        Args:
+            unit (str): Length unit.
+
+        Returns:
+            float, the length converted to the specified unit.
+        """
+        return length_convert(self.__length_unit, unit)
+
+    def convert_energy_to(self, unit):
+        """
+        convert energy to the specified unit.
+
+        Args:
+            unit (str): Energy unit.
+
+        Returns:
+            float, the energy converted to the specified unit.
+        """
+        return energy_convert(self.__energy_unit, unit)
+
+    def convert_length_from(self, unit):
+        """convert length from the specified unit.
+
+        Args:
+            unit (str): Length unit.
+
+        Returns:
+            float, the length converted from the specified unit.
+        """
+        return length_convert(unit, self.__length_unit)
+
+    def convert_energy_from(self, unit):
+        """
+        convert energy from the specified unit.
+
+        Args:
+            unit (str): Energy unit.
+
+        Returns:
+            float, the energy converted from the specified unit.
+        """
+        return energy_convert(unit, self.__energy_unit)
+
+    @property
+    def boltzmann_def(self):
+        """
+        Boltzmann constant in kJ/mol.
+
+        Returns:
+            float, Boltzmann constant in kJ/mol.
+        """
+        return _BOLTZMANN_DEFAULT_REF
+
+    @property
+    def boltzmann(self):
+        """
+        Boltzmann constant in current unit.
+
+        Returns:
+            float, Boltzmann constant in current unit.
+        """
+        return self.__boltzmann
+
+    @property
+    def coulomb(self):
+        """
+        Coulomb constant in current unit.
+
+        Returns:
+            float, Coulomb constant in current unit.
+        """
+        return self.__coulomb
+
+    @property
+    def avogadro(self):
+        """
+        Avogadro number.
+
+        Returns:
+            float, Avogadro number.
+        """
+        return AVOGADRO_NUMBER
+
+    @property
+    def gas_constant(self):
+        """
+        gas constant.
+
+        Returns:
+            float, gas constant.
+        """
+        return GAS_CONSTANT
+
+    @property
+    def length_unit(self):
+        """
+        length unit.
+
+        Returns:
+            length unit (Union[str, float, int]).
+        """
+        return self.__length_unit
+
+    @property
+    def energy_unit(self):
+        """
+        energy unit.
+
+        Returns:
+            energy unit (Union[str, float, int]).
+        """
+        return self.__energy_unit
+
+    @property
+    def length_unit_name(self):
+        """
+        name of length unit.
+
+        Returns:
+            str, name of length unit.
+        """
+        return self.__length_unit_name
+
+    @property
+    def energy_unit_name(self):
+        """
+        name of energy unit.
+
+        Returns:
+            str, name of energy unit.
+        """
+        return self.__energy_unit_name
+
+    @property
+    def volume_unit(self):
+        """
+        volume unit.
+
+        Returns:
+            str, volume unit.
+        """
+        return self.__length_unit + "^3"
+
+    @property
+    def volume_unit_name(self):
+        """
+        volume unit name.
+
+        Returns:
+            str, volume unit name.
+        """
+        return self.__length_unit_name + "^3"
+
+    @property
+    def force_unit(self):
+        """
+        force unit.
+
+        Returns:
+            str, force unit.
+        """
+        return self.__energy_unit + '/' + self.__length_unit
+
+    @property
+    def force_unit_name(self):
+        """
+        name of force unit.
+
+        Returns:
+            str, name of force unit.
+        """
+        return self.__energy_unit_name + ' ' + self.__length_unit_name + '-1'
+
+    @property
+    def velocity_unit(self):
+        """
+        velocity unit.
+
+        Returns:
+            str, velocity unit.
+        """
+        return self.__length_unit + "/" + self.time_unit
+
+    @property
+    def velocity_unit_name(self):
+        """
+        name of velocity unit.
+
+        Returns:
+            str, name of velocity unit.
+        """
+        return self.__length_unit_name + ' ' + self.time_unit + '-1'
+
+    @property
+    def length_ref(self):
+        """
+        reference value of length.
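+        This is the size of one unit of the current length unit expressed
+        in nanometers, the internal base unit.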
+ + Returns: + float, reference value of length. + """ + return self.__length_ref + + @property + def energy_ref(self): + """ + reference value of energy. + + Returns: + float, reference value of energy. + """ + return self.__energy_ref + + @property + def force_ref(self): + """ + reference value of force. + + Returns: + float, reference value of force. + """ + if self.__energy_ref is None: + return None + return self.__energy_ref / self.__length_ref + + @property + def acceleration_ref(self): + """ + reference value of acceleration. + + Returns: + float, reference value of acceleration. + """ + if self.__energy_ref is None or self.__length_ref is None: + return None + return self.__energy_ref / self.__length_ref / self.__length_ref + + @property + def kinetic_ref(self): + """ + reference value of kinetic. + + Returns: + float, reference value of kinetic. + """ + if self.__energy_ref is None or self.__length_ref is None: + return None + return self.__length_ref * self.__length_ref / self.__energy_ref + + @property + def pressure_ref(self): + if self.__energy_ref is None or self.__length_ref is None: + return None + return _BAR_DEFAULT_REF * self.__energy_ref / math.pow(self.__length_ref, 3) + + +global_units = Units('nm', 'kj/mol') + + +def set_global_length_unit(unit): + """ + set length unit for global_units. + + Args: + unit (str): Length unit. Default: None + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + global_units.set_length_unit(unit) + + +def set_global_energy_unit(unit): + """ + set energy unit for global_units. + + Args: + unit (str): Energy unit. Default: None + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + global_units.set_energy_unit(unit) + + +def set_global_units(length_unit, energy_unit, units=None): + """ + set units for global_units. + + Args: + length_unit (str): Length unit. Default: None + energy_unit (str): Energy unit. Default: None + units (str): units. Default: None + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + global_units.set_units(length_unit, energy_unit, units) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f46eceaa0f06253787d4e3be481153fbcb85b775 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Metrics""" + +from .metrics import CV, BalancedMSE, BinaryFocal, MultiClassFocal +from .structure_violations import between_residue_bond, between_residue_clash +from .structure_violations import within_residue_violations, get_structural_violations +from .structure_violations import compute_renamed_ground_truth, frame_aligned_point_error_map +from .structure_violations import backbone, frame_aligned_point_error, sidechain +from .structure_violations import supervised_chi, local_distance_difference_test + +__all__ = ['CV', 'BalancedMSE', 'BinaryFocal', 'MultiClassFocal', "between_residue_bond", + "between_residue_clash", "within_residue_violations", "get_structural_violations", + "compute_renamed_ground_truth", "frame_aligned_point_error_map", + "backbone", "frame_aligned_point_error", "sidechain", "supervised_chi", + "local_distance_difference_test"] + +__all__.sort() diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/metrics.py b/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..41438d76ae67beb083b66cccb62442b736b03e96 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/metrics.py @@ -0,0 +1,369 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Metrics for collective variables +""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.communication.management as D +import mindspore.nn as nn +import mindspore.numpy as mnp + +from mindspore import Parameter, Tensor +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.nn import Metric + +from ..colvar import Colvar + + +class CV(Metric): + """Metric to output collective variables""" + def __init__(self, + colvar: Colvar, + indexes: tuple = (2, 3), + ): + + super().__init__() + self._indexes = indexes + self.colvar = colvar + self._cv_value = None + + def clear(self): + self._cv_value = 0 + + def update(self, *inputs): + coordinate = inputs[self._indexes[0]] + pbc_box = inputs[self._indexes[1]] + self._cv_value = self.colvar(coordinate, pbc_box) + + def eval(self): + return self._cv_value + + +class BalancedMSE(nn.Cell): + r""" + Balanced MSE error + Compute Balanced MSE error between the prediction and the ground truth + to solve unbalanced labels in regression task. + + Reference: + `Ren, Jiawei, et al. 'Balanced MSE for Imbalanced Visual Regression' `_ . + + .. 
math:: + L =-\log \mathcal{N}\left(\boldsymbol{y} ; \boldsymbol{y}_{\text {pred }}, + \sigma_{\text {noise }}^{2} \mathrm{I}\right) + +\log \sum_{i=1}^{N} p_{\text {train }}\left(\boldsymbol{y}_{(i)}\right) + \cdot \mathcal{N}\left(\boldsymbol{y}_{(i)} ; \boldsymbol{y}_{\text {pred }}, + \sigma_{\text {noise }}^{2} \mathrm{I}\right) + + Args: + first_break (float): The begin value of bin. + last_break (float): The end value of bin. + num_bins (int): The bin numbers. + beta (float): The moving average coefficient, default: 0.99. + reducer_flag (bool): Whether to aggregate the label values of multiple devices, default: "False". + + Inputs: + - **prediction** (Tensor) - Predict values, shape is :math:`(batch\_size, ndim)`. + - **target** (Tensor) - Label values, shape is :math:`(batch\_size, ndim)`. + + Outputs: + Tensor, shape is :math:`(batch\_size, ndim)`. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.metrics import BalancedMSE + >>> from mindspore import Tensor + >>> net = BalancedMSE(0, 1, 20) + >>> prediction = Tensor(np.random.randn(32, 10).astype(np.float32)) + >>> target = Tensor(np.random.randn(32, 10).astype(np.float32)) + >>> out = net(prediction, target) + >>> print(out.shape) + (32, 10) + """ + + def __init__(self, first_break, last_break, num_bins, beta=0.99, reducer_flag=False): + super(BalancedMSE, self).__init__() + self.beta = beta + self.first_break = first_break + self.last_break = last_break + self.num_bins = num_bins + + self.breaks = mnp.linspace(self.first_break, self.last_break, self.num_bins) + self.width = self.breaks[1] - self.breaks[0] + bin_width = 2 + start_n = 1 + stop = self.num_bins * 2 + centers = mnp.divide(mnp.arange(start=start_n, stop=stop, step=bin_width), num_bins * 2.0) + self.centers = centers/(self.last_break-self.first_break) + self.first_break + + self.log_noise_scale = Parameter(Tensor([0.], mstype.float32)) + self.p_bins = Parameter(Tensor(np.ones((self.num_bins)) / self.num_bins, dtype=mstype.float32), \ + name='p_bins', requires_grad=False) + + self.softmax = nn.Softmax(-1) + self.zero = Tensor([0.]) + + self.onehot = nn.OneHot(depth=self.num_bins) + self.reducer_flag = reducer_flag + if self.reducer_flag: + self.allreduce = P.AllReduce() + self.device_num = D.get_group_size() + + def construct(self, prediction, target): + """construct""" + p_bins = self._compute_p_bins(prediction) + + log_sigma2 = self.log_noise_scale * 1. + log_sigma2 = 5. * P.Tanh()(log_sigma2 / 5.) + sigma2 = mnp.exp(log_sigma2) + 0.25 * self.width + tau = 2. 
* sigma2
+        a = - F.square(prediction - target) / tau
+
+        ndim = prediction.ndim
+        y_bins = mnp.reshape(self.centers * 1., ndim * (1,) + (-1,))
+        b_term = - F.square(mnp.expand_dims(prediction, -1) - y_bins) / tau
+
+        p_clip = mnp.clip(p_bins, 1e-8, 1 - 1e-8)
+        log_p = mnp.log(p_clip)
+        log_p = mnp.reshape(log_p, ndim * (1,) + (-1,))
+
+        b_term += log_p
+        b = nn.ReduceLogSumExp(-1, False)(b_term)
+
+        err = -a + b
+        return err
+
+    def _compute_p_bins(self, y_gt):
+        """compute bins"""
+        ndim = y_gt.ndim
+        breaks = mnp.reshape(self.breaks, (1,) * ndim + (-1,))
+        y_gt = mnp.expand_dims(y_gt, -1)
+
+        y_bins = (y_gt > breaks).astype(mstype.float32)
+        y_bins = P.ReduceSum()(y_bins, -1).astype(mstype.int32)
+        p_gt = self.onehot(y_bins)
+
+        p_gt = P.Reshape()(p_gt, (-1, self.num_bins))
+        p_bins = P.ReduceMean()(p_gt, 0)
+        if self.reducer_flag:
+            p_bins = self.allreduce(p_bins) / self.device_num
+
+        p_bins = self.beta * self.p_bins + (1 - self.beta) * p_bins
+        P.Assign()(self.p_bins, p_bins)
+
+        return p_bins
+
+
+class MultiClassFocal(nn.Cell):
+    r"""Focal error for multi-class classifications.
+    Compute the multiple classes focal error between `prediction` and the ground truth `target`.
+    Reference:
+        `Lin, Tsung-Yi, et al. 'Focal loss for dense object detection' `_ .
+
+    Args:
+        num_class (int):    The number of classes.
+        beta (float):       The moving average coefficient, default: 0.99.
+        gamma (float):      The focusing hyperparameter, default: 2.0.
+        e (float):          The mixing proportion of the plain cross-entropy term, default: 0.1.
+        neighbors (int):    The number of neighbouring classes to be masked in the target, default: 2.
+        not_focal (bool):   Whether to use plain cross entropy (with neighbour-smoothed labels)
+                            instead of the focal loss, default: "False".
+        reducer_flag (bool): Whether to aggregate the label values of multiple devices, default: "False".
+
+    Inputs:
+        - **prediction** (Tensor) - Predict values, shape is :math:`(batch\_size, ndim)`.
+        - **target** (Tensor) - Label values, shape is :math:`(batch\_size, ndim)`.
+
+    Outputs:
+        Tensor, shape is :math:`(batch\_size,)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindsponge.metrics import MultiClassFocal
+        >>> net = MultiClassFocal(10)
+        >>> prediction = Tensor(np.random.randn(32, 10).astype(np.float32))
+        >>> target = Tensor(np.random.randn(32, 10).astype(np.float32))
+        >>> out = net(prediction, target)
+        >>> print(out.shape)
+        (32,)
+    """
+
+    def __init__(self, num_class, beta=0.99, gamma=2., e=0.1, neighbors=2, not_focal=False, reducer_flag=False):
+        super(MultiClassFocal, self).__init__()
+        self.num_class = num_class
+        self.beta = beta
+        self.gamma = gamma
+        self.e = e
+        self.neighbors = neighbors
+        self.not_focal = not_focal
+
+        neighbor_mask = np.ones((self.num_class, self.num_class))
+        neighbor_mask = neighbor_mask - np.triu(neighbor_mask, neighbors) - np.tril(neighbor_mask, -neighbors)
+        neighbor_mask = neighbor_mask / (np.sum(neighbor_mask, axis=-1, keepdims=True) + 1e-10)
+        self.neighbor_mask = Tensor(neighbor_mask, mstype.float32)
+
+        self.class_weights = Parameter(Tensor(np.ones((self.num_class)) / self.num_class, dtype=mstype.float32), \
+                                       name='class_weights', requires_grad=False)
+
+        self.softmax = nn.Softmax(-1)
+        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
+        self.zero = Tensor([0.])
+
+        self.reducer_flag = reducer_flag
+        if self.reducer_flag:
+            self.allreduce = P.AllReduce()
+
+    def construct(self, prediction, target):
+        """construct"""
+        prediction_tensor = self.softmax(prediction)
+
+        zeros = mnp.zeros_like(prediction_tensor)
+        one_minus_p = mnp.where(target > 1e-5, target - prediction_tensor, zeros)
+        ft = -1 * mnp.power(one_minus_p, self.gamma) * mnp.log(mnp.clip(prediction_tensor, 1e-8, 1.0))
+
+        classes_num = self._compute_classes_num(target)
+        total_num = mnp.sum(classes_num)
+
+        classes_w_t1 = total_num / classes_num
+        sum_ = mnp.sum(classes_w_t1)
+        classes_w_t2 = classes_w_t1 / sum_
+        classes_w_tensor = F.cast(classes_w_t2, mstype.float32)
+
+        weights = self.beta * self.class_weights + (1 - self.beta) * classes_w_tensor
+        P.Assign()(self.class_weights, weights)
+
+        classes_weight = mnp.broadcast_to(mnp.expand_dims(weights, 0), target.shape)
+        alpha = mnp.where(target > zeros, classes_weight, zeros)
+
+        balanced_fl = alpha * ft
+        balanced_fl = mnp.sum(balanced_fl, -1)
+
+        labels = P.MatMul()(target, self.neighbor_mask)
+        xent, _ = self.cross_entropy(prediction, target)
+
+        final_loss = (1 - self.e) * balanced_fl + self.e * xent
+
+        if self.not_focal:
+            softmax_xent, _ = self.cross_entropy(prediction, labels)
+            final_loss = (1 - self.e) * softmax_xent + self.e * xent
+
+        return final_loss
+
+    def _compute_classes_num(self, target):
+        "get global classes number"
+        classes_num = mnp.sum(target, 0)
+        if self.reducer_flag:
+            classes_num = self.allreduce(classes_num)
+        classes_num = F.cast(classes_num, mstype.float32)
+        classes_num += 1.
+        return classes_num
+
+
+class BinaryFocal(nn.Cell):
+    r"""
+    Focal error for binary classifications.
+    Compute the binary classes focal error between `prediction` and the ground truth `target`.
+
+    Reference:
+        `Lin, Tsung-Yi, et al. 'Focal loss for dense object detection' `_ .
+
+    .. math::
+        \mathrm{FL}\left(p_{\mathrm{t}}\right)=-\alpha_{\mathrm{t}}\left(1-p_{\mathrm{t}}\right)^{\gamma}
+        \log \left(p_{\mathrm{t}}\right)
+
+    Args:
+        alpha (float):      The weight of cross entropy, default: 0.25.
+        gamma (float):      The focusing hyperparameter that down-weights easy examples, default: 2.0.
+        feed_in (bool):     Whether `prediction` is fed in as raw logits, in which case a sigmoid
+                            is applied internally to convert it to probabilities, default: "False".
+ not_focal (bool): Whether focal loss, default: "False". + + Inputs: + - **prediction** (Tensor) - Predict values, shape is :math:`(batch\_size, ndim)`. + - **target** (Tensor) - Label values, shape is :math:`(batch\_size, ndim)`. + + Outputs: + Tensor, shape is :math:`(batch\_size,)`. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindsponge.metrics import BinaryFocal + >>> net = BinaryFocal() + >>> prediction = Tensor(np.random.randn(32, 10).astype(np.float32)) + >>> target = Tensor(np.random.randn(32, 10).astype(np.float32)) + >>> out = net(prediction, target) + >>> print(out.shape) + (32,) + """ + + def __init__(self, alpha=0.25, gamma=2., feed_in=False, not_focal=False): + super(BinaryFocal, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.feed_in = feed_in + self.not_focal = not_focal + + self.cross_entropy = P.BinaryCrossEntropy(reduction='none') + self.sigmoid = P.Sigmoid() + self.epsilon = 1e-8 + + def construct(self, prediction, target): + """construct""" + epsilon = self.epsilon + target = F.cast(target, mstype.float32) + probs = F.cast(prediction, mstype.float32) + if self.feed_in: + probs = self.sigmoid(prediction) + else: + prediction = self._convert(prediction) + + ones_tensor = mnp.ones_like(target) + positive_pt = mnp.where(target > 1e-5, probs, ones_tensor) + negative_pt = mnp.where(target < 1e-5, 1 - probs, ones_tensor) + + focal_loss = -self.alpha * mnp.power(1 - positive_pt, self.gamma) * \ + mnp.log(mnp.clip(positive_pt, epsilon, 1.)) - (1 - self.alpha) * \ + mnp.power(1 - negative_pt, self.gamma) * mnp.log(mnp.clip(negative_pt, epsilon, 1.)) + focal_loss *= 2. + + if self.not_focal: + focal_loss = self.cross_entropy(prediction, target, ones_tensor) + + return focal_loss + + def _convert(self, probs): + """convert function""" + probs = mnp.clip(probs, 1e-5, 1. - 1e-5) + prediction = mnp.log(probs / (1 - probs)) + return prediction diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/structure_violations.py b/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/structure_violations.py new file mode 100644 index 0000000000000000000000000000000000000000..1eee89ede4f51df5c5b65c8bf3b9842d4e6bf6df --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/metrics/structure_violations.py @@ -0,0 +1,1228 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Modules and utilities for the structure module.""" +import numpy as np +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import nn, Tensor +from mindspore.ops import operations as P +import mindspore.ops as ops +from ..common.geometry import quaternion_from_tensor +from ..common.utils import find_optimal_renaming +from ..common import residue_constants + + +VIOLATION_TOLERANCE_ACTOR = 12.0 +CLASH_OVERLAP_TOLERANCE = 1.5 + +# one hot encoding for C and N atoms (using atom14 representation) +C_ONE_HOT = Tensor(np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32) +N_ONE_HOT = Tensor(np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32) + +# Van der Waals radii for each atom +ATOMTYPE_RADIUS = \ + np.array([residue_constants.van_der_waals_radius.get(name[0]) for name in residue_constants.atom_types]) +ATOMTYPE_RADIUS = Tensor(ATOMTYPE_RADIUS, ms.float32) +DISTS_MASK_I = Tensor(np.eye(14, 14), ms.int32) + +# lower bound and upper bound between each atoms used for clashes calculation +LOWER_BOUND, UPPER_BOUND, _ = \ + residue_constants.make_atom14_dists_bounds(overlap_tolerance=CLASH_OVERLAP_TOLERANCE, + bond_length_tolerance_factor=VIOLATION_TOLERANCE_ACTOR) +LOWER_BOUND = Tensor(LOWER_BOUND, ms.float32) +UPPER_BOUND = Tensor(UPPER_BOUND, ms.float32) + +CYS_SG_IDX = Tensor(5, ms.int32) + + +def between_residue_bond( + pred_atom_positions, + pred_atom_mask, + residue_index, + aatype, + asym_id, + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0 +): + """ + Flat-bottom loss to penalize structural violations between residues. This is a loss penalizing any violation + of the geometry around the peptide bond between consecutive amino acids. + + Args: + pred_atom_positions (Tensor): Atom positions in atom37/14 representation, shape :math:`(N_{res}, 37, 3)`. + or shape :math:`(N_{res}, 14, 3)` . + pred_atom_mask (Tensor): Atom mask in atom37/14 representation. shape :math:`(N_{res}, 37)` or + shape :math:`(N_{res}, 14)` . + residue_index (Tensor): Residue index for given amino acid, this is assumed to be monotonically + increasing. shape :math:`(N_{res}, )` . + aatype (Tensor): amino acid types. shape :math:`(N_{res}, )` . + tolerance_factor_soft (float): soft tolerance factor measured in standard deviations of pdb distributions. + Default: 12.0 . + tolerance_factor_hard (float): hard tolerance factor measured in standard deviations of pdb distributions. + Default: 12.0 . + + Returns: + - Tensor, c_n_loss_mean, loss for peptide bond length violations. shape is () . + - Tensor, ca_c_n_loss_mean, loss for violations of bond angle around C spanned by CA, C, N. shape is () . + - Tensor, c_n_ca_loss_mean, loss for violations of bond angle around N spanned by C, N, CA. shape is () . + - Tensor, per_residue_loss_sum, sum of all losses of each residue. shape is :math:`(N_{res}, )` . + - Tensor, per_residue_violation_mask, mask denoting all residues with violation present. + shape is :math:`(N_{res}, )` . + + Symbol: + :math:`N_{res}`, number of amino acids. 
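+
+    Note:
+        `asym_id` holds the chain index of each residue. Residue pairs that
+        cross a chain boundary are masked out of all three geometry losses,
+        so the peptide-bond penalties only apply within a single chain.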
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> import numpy as np + >>> from mindsponge.metrics import between_residue_bond + >>> np.random.seed(1) + >>> pred_atom_positions = Tensor(np.random.random(size=(50,37,3)), ms.float32) + >>> pred_atom_mask = Tensor(np.random.randint(2,size=(50,37)), ms.int32) + >>> residue_index = Tensor(np.array(range(50)), ms.int32) + >>> aatype = Tensor(np.random.randint(20, size=(50,)), ms.int32) + >>> tolerance_factor_soft = 12.0 + >>> tolerance_factor_hard = 12.0 + >>> result = between_residue_bond(pred_atom_positions, pred_atom_mask, residue_index, aatype, + >>> tolerance_factor_soft, tolerance_factor_hard) + >>> for x in result: + >>> print(x) + 0.52967054 + 0.6045412 + 0.39251995 + [0.62809587 1.6770853 1.7221183 1.0325309 1.3417522 1.79882 + 1.7718308 1.5092779 1.5653987 1.9564128 1.6804926 1.6051245 + 1.5033073 1.5895741 2.1686926 2.126039 1.3837843 1.2554975 + 1.8135165 2.1593785 1.9408598 1.7281027 1.8666006 1.9623451 + 1.8177024 1.7543832 1.5969353 1.2150483 0.9833115 1.219868 + 1.7008476 1.6968286 1.7648234 1.5584714 1.370602 1.8525059 + 1.7938454 1.5313196 1.6940074 1.8512855 1.8222975 1.6600168 + 1.9163743 1.7201058 1.6288358 1.6055745 1.521946 1.6553445 + 1.6175683 0.894606 ] + [1. 1. 0. 1. 1. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0. 1. 1. 0. + 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. + 1. 1.] + + """ + + # Get the positions of the relevant backbone atoms. + this_ca_pos = pred_atom_positions[:-1, 1, :] + this_ca_mask = pred_atom_mask[:-1, 1] + this_c_pos = pred_atom_positions[:-1, 2, :] + this_c_mask = pred_atom_mask[:-1, 2] + next_n_pos = pred_atom_positions[1:, 0, :] + next_n_mask = pred_atom_mask[1:, 0] + next_ca_pos = pred_atom_positions[1:, 1, :] + next_ca_mask = pred_atom_mask[1:, 1] + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(ms.float32) + has_no_gap_mask = has_no_gap_mask * (asym_id[..., :-1] == asym_id[..., 1:]).astype(ms.float32) + # Compute loss for the C--N bond. + c_n_bond_length = mnp.sqrt(1e-6 + mnp.sum(mnp.square(this_c_pos - next_n_pos), axis=-1)) + + # The C-N bond to proline has slightly different length because of the ring. + next_is_proline = (aatype[1:] == residue_constants.resname_to_idx['PRO']).astype(ms.float32) + gt_length = ((1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_c_n[1]) + gt_stddev = ((1. - next_is_proline) * residue_constants.between_res_bond_length_stddev_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]) + c_n_bond_length_error = mnp.sqrt(1e-6 + mnp.square(c_n_bond_length - gt_length)) + c_n_loss_per_residue = nn.ReLU()(c_n_bond_length_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss_mean = mnp.sum(mask * c_n_loss_per_residue) / (mnp.sum(mask) + 1e-6) + c_n_violation_mask = mask * (c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + + # Compute loss for the angles. 
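+    # The two backbone bond angles are penalized through their cosines: unit
+    # vectors along CA->C, C->N and N->CA are built below, their dot products
+    # are compared with reference cosines from PDB statistics, and only
+    # deviations beyond tolerance_factor_soft standard deviations contribute,
+    # giving a flat-bottom (ReLU) penalty.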
+ ca_c_bond_length = mnp.sqrt(1e-6 + mnp.sum(mnp.square(this_ca_pos - this_c_pos), axis=-1)) + n_ca_bond_length = mnp.sqrt(1e-6 + mnp.sum(mnp.square(next_n_pos - next_ca_pos), axis=-1)) + + c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[:, None] + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[:, None] + n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[:, None] + + ca_c_n_cos_angle = mnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_cos_angles_ca_c_n[1] + ca_c_n_cos_angle_error = mnp.sqrt(1e-6 + mnp.square(ca_c_n_cos_angle - gt_angle)) + ca_c_n_loss_per_residue = nn.ReLU()(ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss_mean = mnp.sum(mask * ca_c_n_loss_per_residue) / (mnp.sum(mask) + 1e-6) + ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + c_n_ca_cos_angle = mnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = mnp.sqrt(1e-6 + mnp.square(c_n_ca_cos_angle - gt_angle)) + c_n_ca_loss_per_residue = nn.ReLU()(c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss_mean = mnp.sum(mask * c_n_ca_loss_per_residue) / (mnp.sum(mask) + 1e-6) + c_n_ca_violation_mask = mask * (c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + # Compute a per residue loss (equally distribute the loss to both neighbouring residues). + per_residue_loss_sum = c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue + per_residue_loss_sum = 0.5 * (mnp.pad(per_residue_loss_sum, [[0, 1]]) + mnp.pad(per_residue_loss_sum, [[1, 0]])) + + # Compute hard violations. + per_residue_violation_mask = mnp.max(mnp.stack([c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask]), + axis=0) + per_residue_violation_mask = mnp.maximum(mnp.pad(per_residue_violation_mask, [[0, 1]]), + mnp.pad(per_residue_violation_mask, [[1, 0]])) + + return c_n_loss_mean, ca_c_n_loss_mean, c_n_ca_loss_mean, per_residue_loss_sum, per_residue_violation_mask + + +def between_residue_clash( + atom14_pred_positions, + atom14_atom_exists, + atom14_atom_radius, + residue_index, + asym_id, + c_one_hot, + n_one_hot, + overlap_tolerance_soft, + overlap_tolerance_hard, + cys_sg_idx): + """ + This is a loss penalizing any steric clashes due to non bonded atoms in different peptides coming too close. + + Args: + atom14_pred_positions (Tensor): predicted positions of atoms in global prediction frame. + shape is :math:`(N_{res}, 14, 3)` . + atom14_atom_exists (Tensor): mask denoting whether atom at positions exists for given amino acid type. + shape is :math:`(N_{res}, 14)` . + atom14_atom_radius (Tensor): Van der Waals radius for each atom. shape is :math:`(N_{res}, 14)` . + residue_index (Tensor): Residue index for given amino acid. shape is :math:`(N_{res}, )` . + c_one_hot (Tensor): one hot encoding for C atoms (using atom14 representation). shape is (14, ) . + n_one_hot (Tensor): one hot encoding for N atoms (using atom14 representation). shape is (14, ) . + overlap_tolerance_soft (float): soft tolerance factor. in default: 12.0. + overlap_tolerance_hard (float): hard tolerance factor. in default: 1.5. 
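+        asym_id (Tensor): chain index of each residue; the same-residue and
+            peptide-bond exclusions are only applied to atom pairs within
+            the same chain. shape is :math:`(N_{res}, )` .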
+ cys_sg_idx (Tensor): CYS amino acid index. Default: 5. + see more at `mindsponge.common.residue_constants`. + + Returns: + - Tensor, mean_loss, average clash loss. Shape is () . + - Tensor, per_atom_loss_sum, sum of all clash losses per atom, shape is :math:`(N_{res}, 14)` . + - Tensor, per_atom_clash_mask, mask whether atom clashes with any other atom, + shape is :math:`(N_{res}, 14)` . + + Symbol: + :math:`N_{res}`, number of amino acids. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> import numpy as np + >>> from mindsponge.metrics import between_residue_clash + >>> atom14_pred_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32) + >>> atom14_atom_exists = Tensor(np.random.randint(2, size=(50, 14))) + >>> atom14_atom_radius = Tensor(np.random.random(size=(50, 14)), ms.float32) + >>> residue_index = Tensor(np.array(range(50)), ms.int32) + >>> c_one_hot = Tensor(np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32) + >>> n_one_hot = Tensor(np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32) + >>> overlap_tolerance_soft = 12.0 + >>> overlap_tolerance_hard = 1.5 + >>> cys_sg_idx = Tensor(5, ms.int32) + >>> mean_loss, per_atom_loss_sum, per_atom_clash_mask = between_residue_clash(atom14_pred_positions, + ... atom14_atom_exists, + ... atom14_atom_radius, + ... residue_index, + ... c_one_hot, + ... n_one_hot, + ... overlap_tolerance_soft, + ... overlap_tolerance_hard, + ... cys_sg_idx) + >>> print(mean_loss.shape, per_atom_loss_sum.shape, per_atom_clash_mask.shape) + () (50,14) (50,14) + + """ + + dists = mnp.sqrt(1e-10 + mnp.sum( + mnp.square(atom14_pred_positions[:, None, :, None, :] - atom14_pred_positions[None, :, None, :, :]), axis=-1)) + dists_mask = atom14_atom_exists[:, None, :, None] * atom14_atom_exists[None, :, None, :] + dists_mask *= (residue_index[:, None, None, None] <= residue_index[None, :, None, None]) + + diagonal = (residue_index[:, None, None, None] == residue_index[None, :, None, None]).astype(ms.float32) + in_one_chain = ( + asym_id[:, None, None, None] == asym_id[None, :, None, None] + ).astype(ms.float32) + diagonal = diagonal * in_one_chain + dists_mask = dists_mask * (1. - diagonal) + # Backbone C--N bond between subsequent residues is no clash. + neighbour_mask = ((residue_index[:, None, None, None] + 1) == residue_index[None, :, None, None]).astype(ms.float32) + neighbour_mask *= (asym_id[:, None, None, None] == asym_id[None, :, None, None]).astype(ms.float32) + c_n_bonds = neighbour_mask * c_one_hot[None, None, :, None] * n_one_hot[None, None, None, :] + dists_mask *= (1. - c_n_bonds) + + + # Disulfide bridge between two cysteines is no clash. + cys_sg_one_hot = nn.OneHot(depth=14)(cys_sg_idx) + disulfide_bonds = (cys_sg_one_hot[None, None, :, None] * cys_sg_one_hot[None, None, None, :]) + dists_mask *= (1. 
+
+    dists_lower_bound = dists_mask * (atom14_atom_radius[:, None, :, None] + atom14_atom_radius[None, :, None, :])
+    dists_to_low_error = dists_mask * nn.ReLU()(dists_lower_bound - overlap_tolerance_soft - dists)
+    mean_loss = mnp.sum(dists_to_low_error) / (1e-6 + mnp.sum(dists_mask))
+    per_atom_loss_sum = P.ReduceSum()(dists_to_low_error, (0, 2)) + P.ReduceSum()(dists_to_low_error, (1, 3))
+    clash_mask = dists_mask * (dists < (dists_lower_bound - overlap_tolerance_hard))
+    per_atom_clash_mask = mnp.maximum(mnp.max(clash_mask, axis=[0, 2]), mnp.max(clash_mask, axis=[1, 3]))
+    per_atom_clash_count = P.ReduceSum()(clash_mask, (0, 2)) + P.ReduceSum()(clash_mask, (1, 3))
+    return mean_loss, per_atom_loss_sum, per_atom_clash_mask, per_atom_clash_count
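+
+# A worked example of the soft clash penalty above (hypothetical values): two atoms with
+# Van der Waals radius 1.7 each give a lower bound of 3.4, so with overlap_tolerance_soft = 1.5
+# a pair at distance 2.0 contributes ReLU(3.4 - 1.5 - 2.0) = 0.0, while a pair at distance 1.0
+# contributes ReLU(3.4 - 1.5 - 1.0) = 0.9.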
+
+
+def within_residue_violations(
+        atom14_pred_positions,
+        atom14_atom_exists,
+        atom14_dists_lower_bound,
+        atom14_dists_upper_bound,
+        tighten_bounds_for_loss,
+        dists_mask_i
+):
+    """Loss to penalize steric clashes within residues.
+
+    This is a loss penalizing any steric violations or clashes of non-bonded atoms within a given peptide.
+
+    Args:
+        atom14_pred_positions (Tensor): predicted positions of atoms in the global prediction frame.
+            shape :math:`(N_{res}, 14, 3)` .
+        atom14_atom_exists (Tensor): mask denoting whether the atom at each position exists for the given
+            amino acid type. shape :math:`(N_{res}, 14)` .
+        atom14_dists_lower_bound (Tensor): lower bound on allowed distances. shape :math:`(N_{res}, 14, 14)` .
+        atom14_dists_upper_bound (Tensor): upper bound on allowed distances. shape :math:`(N_{res}, 14, 14)` .
+        tighten_bounds_for_loss (float): extra factor to tighten the bounds in the loss. Default: 0.0.
+        dists_mask_i (Tensor): initial distance mask, shape: (14, 14) .
+
+    Returns:
+        - **per_atom_loss_sum** (Tensor) - sum of all violation losses per atom, shape :math:`(N_{res}, 14)` .
+        - **per_atom_violations** (Tensor) - violation per atom, shape :math:`(N_{res}, 14)` .
+        - **per_atom_clash_count** (Tensor) - number of bound violations per atom, shape :math:`(N_{res}, 14)` .
+
+    Symbol:
+        :math:`N_{res}`, number of amino acids.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>> from mindsponge.metrics import within_residue_violations
+        >>> atom14_pred_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32)
+        >>> atom14_atom_exists = Tensor(np.random.random(size=(50, 14)), ms.float32)
+        >>> atom14_dists_lower_bound = Tensor(np.random.random(size=(50, 14, 14)), ms.float32)
+        >>> atom14_dists_upper_bound = Tensor(np.random.random(size=(50, 14, 14)), ms.float32)
+        >>> tighten_bounds_for_loss = 0.0
+        >>> dists_mask_i = Tensor(np.eye(14, 14), ms.int32)
+        >>> per_atom_loss_sum, per_atom_violations, per_atom_clash_count = within_residue_violations(
+        ...     atom14_pred_positions, atom14_atom_exists, atom14_dists_lower_bound, atom14_dists_upper_bound,
+        ...     tighten_bounds_for_loss, dists_mask_i)
+        >>> print(per_atom_loss_sum.shape, per_atom_violations.shape, per_atom_clash_count.shape)
+        (50, 14) (50, 14) (50, 14)
+
+    """
+
+    dists_masks = (1. - dists_mask_i[None])
+    dists_masks *= (atom14_atom_exists[:, :, None] * atom14_atom_exists[:, None, :])
+
+    dists = mnp.sqrt(1e-10 + mnp.sum(
+        mnp.square(atom14_pred_positions[:, :, None, :] - atom14_pred_positions[:, None, :, :]), axis=-1))
+    dists_to_low_error = nn.ReLU()(atom14_dists_lower_bound + tighten_bounds_for_loss - dists)
+    dists_to_high_error = nn.ReLU()(dists - (atom14_dists_upper_bound - tighten_bounds_for_loss))
+    loss = dists_masks * (dists_to_low_error + dists_to_high_error)
+    per_atom_loss_sum = mnp.sum(loss, axis=1) + mnp.sum(loss, axis=2)
+    lower = (dists < atom14_dists_lower_bound).astype(ms.int32)
+    high = (dists > atom14_dists_upper_bound).astype(ms.int32)
+    violations = dists_masks * ((lower + high).astype(bool))
+
+    per_atom_violations = mnp.maximum(mnp.max(violations, axis=1), mnp.max(violations, axis=2))
+    per_atom_clash_count = mnp.sum(violations, axis=1) + mnp.sum(violations, axis=2)
+    return per_atom_loss_sum, per_atom_violations, per_atom_clash_count
+
+
+def get_structural_violations(atom14_atom_exists, residue_index, aatype, residx_atom14_to_atom37,
+                              atom14_pred_positions, asym_id, violation_tolerance_factor=VIOLATION_TOLERANCE_ACTOR,
+                              clash_overlap_tolerance=CLASH_OVERLAP_TOLERANCE, lower_bound=LOWER_BOUND,
+                              upper_bound=UPPER_BOUND, atomtype_radius=ATOMTYPE_RADIUS,
+                              c_one_hot=C_ONE_HOT, n_one_hot=N_ONE_HOT, dists_mask_i=DISTS_MASK_I,
+                              cys_sg_idx=CYS_SG_IDX):
+    """Computes several checks for structural violations.
+
+    Args:
+        atom14_atom_exists (Tensor): mask denoting whether the atom at each position exists for the given
+            amino acid type. shape :math:`(N_{res}, 14)` .
+        residue_index (Tensor): residue index for each amino acid. shape :math:`(N_{res}, )` .
+        aatype (Tensor): amino acid types. shape :math:`(N_{res}, )` .
+        residx_atom14_to_atom37 (Tensor): mapping for (residx, atom14) --> atom37. shape :math:`(N_{res}, 14)` .
+        atom14_pred_positions (Tensor): predicted positions of atoms in the global prediction frame.
+            shape :math:`(N_{res}, 14, 3)` .
+        asym_id (Tensor): chain (asymmetric unit) index for each residue. shape :math:`(N_{res}, )` .
+        violation_tolerance_factor (float): tolerance factor for violations between amino acids. Default: 12.0 .
+        clash_overlap_tolerance (float): clash overlap tolerance factor. Default: 1.5 .
+        lower_bound (Tensor): lower bound on allowed distances. shape :math:`(N_{res}, 14, 14)` .
+        upper_bound (Tensor): upper bound on allowed distances. shape :math:`(N_{res}, 14, 14)` .
+        atomtype_radius (Tensor): Van der Waals radius for each atom type. shape: (37, ) .
+        c_one_hot (Tensor): one-hot encoding for C atoms (using atom14 representation). shape: (14, ) .
+        n_one_hot (Tensor): one-hot encoding for N atoms (using atom14 representation). shape: (14, ) .
+        dists_mask_i (Tensor): initial distance mask, shape: (14, 14) .
+        cys_sg_idx (Tensor): CYS SG atom index. Default: 5 .
+            see more at `mindsponge.common.residue_constants`.
+
+    Returns:
+        - bonds_c_n_loss_mean (Tensor), loss for peptide bond length violations. shape is () .
+        - angles_ca_c_n_loss_mean (Tensor), loss for violations of the bond angle around C spanned by CA, C, N.
+          shape is () .
+        - angles_c_n_ca_loss_mean (Tensor), loss for violations of the bond angle around N spanned by C, N, CA.
+          shape is () .
+        - connections_per_residue_loss_sum (Tensor), sum of all losses of each residue. shape is :math:`(N_{res}, )` .
+        - connections_per_residue_violation_mask (Tensor), mask denoting all residues with a violation present.
+          shape is :math:`(N_{res}, )` .
+        - clashes_mean_loss (Tensor), average clash loss. shape: () .
+        - clashes_per_atom_loss_sum (Tensor), sum of all clash losses per atom, shape :math:`(N_{res}, 14)` .
+        - clashes_per_atom_clash_mask (Tensor), mask whether atom clashes with any other atom.
+          shape :math:`(N_{res}, 14)` .
+        - per_atom_loss_sum (Tensor), sum of all within-residue violation losses per atom,
+          shape :math:`(N_{res}, 14)` .
+        - per_atom_violations (Tensor), violation per atom, shape :math:`(N_{res}, 14)` .
+        - total_per_residue_violations_mask (Tensor), violation masks for all residues, shape :math:`(N_{res}, )` .
+        - structure_violation_loss (Tensor), total violations for all amino acids. shape is () .
+        - clashes_per_atom_clash_count (Tensor), number of between-residue clashes per atom,
+          shape :math:`(N_{res}, 14)` .
+        - per_atom_clash_count (Tensor), number of within-residue violations per atom, shape :math:`(N_{res}, 14)` .
+
+    Symbol:
+        :math:`N_{res}`, number of amino acids.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>> from mindsponge.metrics import get_structural_violations
+        >>> atom14_atom_exists = Tensor(np.random.random(size=(50, 14)), ms.float32)
+        >>> residue_index = Tensor(np.array(range(50)), ms.int32)
+        >>> aatype = Tensor(np.random.randint(20, size=(50,)), ms.int32)
+        >>> residx_atom14_to_atom37 = Tensor(np.random.randint(2, size=(50, 14)), ms.int32)
+        >>> atom14_pred_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32)
+        >>> asym_id = Tensor(np.zeros((50,)), ms.int32)
+        >>> violation_tolerance_factor = 12.0
+        >>> clash_overlap_tolerance = 1.5
+        >>> lower_bound = Tensor(np.random.random(size=(50, 14, 14)), ms.float32)
+        >>> upper_bound = Tensor(np.random.random(size=(50, 14, 14)), ms.float32)
+        >>> atomtype_radius = Tensor([1.55, 1.7, 1.7, 1.7, 1.52, 1.7, 1.7, 1.7, 1.52, 1.52, 1.8,
+        ...                           1.7, 1.7, 1.7, 1.55, 1.55, 1.52, 1.52, 1.8, 1.7, 1.7, 1.7,
+        ...                           1.7, 1.55, 1.55, 1.55, 1.52, 1.52, 1.7, 1.55, 1.55, 1.52, 1.7,
+        ...                           1.7, 1.7, 1.55, 1.52], ms.float32)
+        >>> c_one_hot = Tensor(np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32)
+        >>> n_one_hot = Tensor(np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32)
+        >>> dists_mask_i = Tensor(np.eye(14, 14), ms.int32)
+        >>> cys_sg_idx = Tensor(5, ms.int32)
+        >>> result = get_structural_violations(atom14_atom_exists, residue_index, aatype, residx_atom14_to_atom37,
+        ...                                    atom14_pred_positions, asym_id, violation_tolerance_factor,
+        ...                                    clash_overlap_tolerance, lower_bound, upper_bound, atomtype_radius,
+        ...                                    c_one_hot, n_one_hot, dists_mask_i, cys_sg_idx)
+        >>> for r in result:
+        ...     print(r.shape)
+        ()
+        ()
+        ()
+        (50,)
+        (50,)
+        ()
+        (50, 14)
+        (50, 14)
+        (50, 14)
+        (50, 14)
+        (50,)
+        ()
+        (50, 14)
+        (50, 14)
+
+    """
+
+    # Compute between-residue backbone violations of bonds and angles.
+    c_n_loss_mean, ca_c_n_loss_mean, c_n_ca_loss_mean, per_residue_loss_sum, per_residue_violation_mask = \
+        between_residue_bond(
+            pred_atom_positions=atom14_pred_positions,
+            pred_atom_mask=atom14_atom_exists.astype(mnp.float32),
+            residue_index=residue_index.astype(mnp.float32),
+            aatype=aatype,
+            asym_id=asym_id,
+            tolerance_factor_soft=violation_tolerance_factor,
+            tolerance_factor_hard=violation_tolerance_factor)
+    # Compute the Van der Waals radius for every atom (the first letter of the atom name is the element type).
+    # Shape: (N, 14).
+    atom14_atom_radius = atom14_atom_exists * P.Gather()(atomtype_radius, residx_atom14_to_atom37, 0)
+
+    # Compute the between-residue clash loss.
+    mean_loss, clashes_per_atom_loss_sum, per_atom_clash_mask, clashes_per_atom_clash_count = between_residue_clash(
+        atom14_pred_positions=atom14_pred_positions,
+        atom14_atom_exists=atom14_atom_exists,
+        atom14_atom_radius=atom14_atom_radius,
+        residue_index=residue_index,
+        asym_id=asym_id,
+        c_one_hot=c_one_hot,
+        n_one_hot=n_one_hot,
+        overlap_tolerance_soft=clash_overlap_tolerance,
+        overlap_tolerance_hard=clash_overlap_tolerance,
+        cys_sg_idx=cys_sg_idx
+    )
+
+    # Compute all within-residue violations (clashes, bond length and angle violations).
+    atom14_dists_lower_bound = P.Gather()(lower_bound, aatype, 0)
+    atom14_dists_upper_bound = P.Gather()(upper_bound, aatype, 0)
+    per_atom_loss_sum, per_atom_violations, per_atom_clash_count = within_residue_violations(
+        atom14_pred_positions=atom14_pred_positions,
+        atom14_atom_exists=atom14_atom_exists,
+        atom14_dists_lower_bound=atom14_dists_lower_bound,
+        atom14_dists_upper_bound=atom14_dists_upper_bound,
+        tighten_bounds_for_loss=0.0,
+        dists_mask_i=dists_mask_i)
+
+    # Combine them into a single per-residue violation mask (used later for LDDT).
+    per_residue_violations_mask = mnp.max(mnp.stack([per_residue_violation_mask, mnp.max(per_atom_clash_mask, axis=-1),
+                                                     mnp.max(per_atom_violations, axis=-1)]), axis=0)
+    bonds_c_n_loss_mean = c_n_loss_mean
+    angles_ca_c_n_loss_mean = ca_c_n_loss_mean
+    angles_c_n_ca_loss_mean = c_n_ca_loss_mean
+    connections_per_residue_loss_sum = per_residue_loss_sum
+    connections_per_residue_violation_mask = per_residue_violation_mask
+    clashes_mean_loss = mean_loss
+    clashes_per_atom_clash_mask = per_atom_clash_mask
+    total_per_residue_violations_mask = per_residue_violations_mask
+    num_atoms = P.ReduceSum()(atom14_atom_exists.astype(ms.float32))
+    structure_violation_loss = bonds_c_n_loss_mean + angles_ca_c_n_loss_mean + angles_c_n_ca_loss_mean +\
+        P.ReduceSum()(clashes_per_atom_loss_sum + per_atom_loss_sum) / (1e-6 + num_atoms)
+    return (bonds_c_n_loss_mean, angles_ca_c_n_loss_mean, angles_c_n_ca_loss_mean, connections_per_residue_loss_sum,
+            connections_per_residue_violation_mask, clashes_mean_loss, clashes_per_atom_loss_sum,
+            clashes_per_atom_clash_mask, per_atom_loss_sum, per_atom_violations, total_per_residue_violations_mask,
+            structure_violation_loss, clashes_per_atom_clash_count, per_atom_clash_count)
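+
+# Reading of the scalar loss returned above (informal notation):
+#   structure_violation_loss = L_bond(C-N) + L_angle(CA-C-N) + L_angle(C-N-CA)
+#       + (sum(between-residue clash loss) + sum(within-residue loss)) / (1e-6 + num_atoms)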
+
+
+def compute_renamed_ground_truth(atom14_gt_positions,
+                                 atom14_alt_gt_positions,
+                                 atom14_atom_is_ambiguous,
+                                 atom14_gt_exists,
+                                 atom14_pred_positions,
+                                 atom14_alt_gt_exists):
+    """
+    Find the optimal renaming of the ground truth based on the predicted positions.
+
+    Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms"
+
+    This renamed ground truth is then used for all losses,
+    such that each loss moves the atoms in the same direction.
+
+    Args:
+        atom14_gt_positions (Tensor): Ground truth positions. shape :math:`(N_{res}, 14, 3)` .
+        atom14_alt_gt_positions (Tensor): Ground truth positions with renaming swaps. shape :math:`(N_{res}, 14, 3)` .
+        atom14_atom_is_ambiguous (Tensor): 1.0 for atoms that are affected by renaming swaps.
+            shape :math:`(N_{res}, 14)` .
+        atom14_gt_exists (Tensor): Mask for which atoms exist in the ground truth. shape :math:`(N_{res}, 14)` .
+        atom14_pred_positions (Tensor): Array of atom positions in the global frame with shape
+            :math:`(N_{res}, 14, 3)` .
+        atom14_alt_gt_exists (Tensor): Mask for which atoms exist in the ground truth after renaming.
+            shape :math:`(N_{res}, 14)` .
+
+    Returns:
+        - **alt_naming_is_better** (Tensor) - Array with 1.0 where the alternative swap is better.
+          shape :math:`(N_{res}, )` .
+        - **renamed_atom14_gt_positions** (Tensor) - Array of optimal ground truth positions after renaming swaps
+          are performed. shape :math:`(N_{res}, 14, 3)` .
+        - **renamed_atom14_gt_exists** (Tensor) - Mask after the renaming swap is performed.
+          shape :math:`(N_{res}, 14)` .
+
+    Symbol:
+        :math:`N_{res}`, number of amino acids.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>> from mindsponge.metrics import compute_renamed_ground_truth
+        >>> atom14_gt_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32)
+        >>> atom14_alt_gt_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32)
+        >>> atom14_atom_is_ambiguous = Tensor(np.random.random(size=(50, 14)), ms.float32)
+        >>> atom14_gt_exists = Tensor(np.random.random(size=(50, 14)), ms.float32)
+        >>> atom14_pred_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32)
+        >>> atom14_alt_gt_exists = Tensor(np.random.random(size=(50, 14)), ms.float32)
+        >>> alt_naming_is_better, renamed_atom14_gt_positions, renamed_atom14_gt_exists = \
+        ...     compute_renamed_ground_truth(atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous,
+        ...                                  atom14_gt_exists, atom14_pred_positions, atom14_alt_gt_exists)
+        >>> print(alt_naming_is_better.shape, renamed_atom14_gt_positions.shape, renamed_atom14_gt_exists.shape)
+        (50,) (50, 14, 3) (50, 14)
+
+    """
+
+    alt_naming_is_better = find_optimal_renaming(atom14_gt_positions,
+                                                 atom14_alt_gt_positions,
+                                                 atom14_atom_is_ambiguous,
+                                                 atom14_gt_exists,
+                                                 atom14_pred_positions)
+
+    renamed_atom14_gt_positions = ((1. - alt_naming_is_better[:, None, None]) * atom14_gt_positions
+                                   + alt_naming_is_better[:, None, None] * atom14_alt_gt_positions)
+
+    renamed_atom14_gt_mask = ((1. - alt_naming_is_better[:, None]) * atom14_gt_exists
+                              + alt_naming_is_better[:, None] * atom14_alt_gt_exists)
+
+    return alt_naming_is_better, renamed_atom14_gt_positions, renamed_atom14_gt_mask
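+
+# The renaming above is a per-residue convex blend: with alt = alt_naming_is_better in {0, 1},
+#   renamed = (1 - alt) * gt + alt * alt_gt
+# so each residue either keeps its original atom naming or swaps it entirely.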
+
+
+def frame_aligned_point_error_map(pred_frames,
+                                  target_frames,
+                                  frames_mask,
+                                  pred_positions,
+                                  target_positions,
+                                  positions_mask,
+                                  length_scale,
+                                  l1_clamp_distance,
+                                  pair_mask,
+                                  sbr_mask):
+    r"""Measure point error under different alignments: this computes the error between two
+    structures with B points under A alignments derived from the given pairs of frames.
+    Similar to the `frame_aligned_point_error` function, except that this is a batched
+    version which returns the error of each group of local frames (each recycle) individually;
+    this version considers only backbone frames.
+
+    Args:
+        pred_frames (list): The predicted backbone frames, a 2-dimensional list:
+            the first element is a list of 9 tensors which are the 9 components of the
+            rotation matrix; the second element is a list of 3 tensors which are the 3
+            components of the translation. All tensors are of shape :math:`(N_{recycle}, N_{res})`,
+            with :math:`N_{recycle}` the number of recycles of FoldIteration in the Structure module
+            and :math:`N_{res}` the number of residues in the protein.
+        target_frames (list): The ground truth backbone frames, also a 2-dimensional
+            list, the same as pred_frames except that the shape of the tensors is :math:`(N_{res},)`.
+        frames_mask (Tensor): The binary mask for frames of shape :math:`(N_{res},)`.
+        pred_positions (list): The predicted Ca atom positions, a list of 3
+            tensors of shape :math:`(N_{recycle}, N_{res})`.
+        target_positions (list): The ground truth Ca atom positions, a list
+            of 3 tensors of shape :math:`(N_{res},)`.
+        positions_mask (Tensor): The binary mask for Ca atom positions of shape :math:`(N_{res},)`.
+        length_scale (float): The unit distance which is used to scale distances.
+        l1_clamp_distance (float): Distance cutoff on error beyond which gradients will
+            be zero.
+        pair_mask (Tensor): Binary mask of residue pairs over which the error is accumulated
+            (e.g. intra-chain or inter-chain pairs), shape :math:`(N_{res}, N_{res})`.
+        sbr_mask (Tensor): Binary mask of restrained residue pairs used for the additional
+            clamped FAPE term, shape :math:`(N_{res}, N_{res})`.
+
+    Returns:
+        - **error_clamp** (Tensor) - Clamped backbone FAPE loss with shape :math:`(N_{recycle},)`.
+        - **error_no_clamp** (Tensor) - Unclamped backbone FAPE loss with shape :math:`(N_{recycle},)`.
+        - **sbr_fape_clamp** (Tensor) - Clamped backbone FAPE restricted to the pairs selected by
+          sbr_mask, with shape :math:`(N_{recycle},)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.metrics import frame_aligned_point_error_map
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> np.random.seed(0)
+        >>> rot_matrix = [[Tensor(np.random.rand(8, 256)).astype(mstype.float32) for _ in range(9)]]
+        >>> trans_matrix = [[Tensor(np.random.rand(8, 256)).astype(mstype.float32) for _ in range(3)]]
+        >>> pred_frames = rot_matrix + trans_matrix
+        >>> rot_matrix = [[Tensor(np.random.rand(256,)).astype(mstype.float32) for _ in range(9)]]
+        >>> trans_matrix = [[Tensor(np.random.rand(256,)).astype(mstype.float32) for _ in range(3)]]
+        >>> target_frames = rot_matrix + trans_matrix
+        >>> frames_mask = Tensor(np.random.rand(256,)).astype(mstype.float32)
+        >>> positions_mask = Tensor(np.random.rand(256,)).astype(mstype.float32)
+        >>> pred_positions = [Tensor(np.random.rand(8, 256)).astype(mstype.float32) for _ in range(3)]
+        >>> target_positions = [Tensor(np.random.rand(256,)).astype(mstype.float32) for _ in range(3)]
+        >>> length_scale = 10.0
+        >>> l1_clamp_distance = 10.0
+        >>> pair_mask = Tensor(np.ones((256, 256))).astype(mstype.float32)
+        >>> sbr_mask = Tensor(np.ones((256, 256))).astype(mstype.float32)
+        >>> error, error_noclamp, sbr_fape = frame_aligned_point_error_map(pred_frames, target_frames,
+        ...                                                                frames_mask, pred_positions,
+        ...                                                                target_positions, positions_mask,
+        ...                                                                length_scale, l1_clamp_distance,
+        ...                                                                pair_mask, sbr_mask)
+        >>> print(error, error_noclamp)
+        [0.0827449 0.08608595 0.09045469 0.08518302 0.08452212 0.08624027 0.08426301 0.08154671]
+        [0.0827449 0.08608595 0.09045469 0.08518302 0.08452212 0.08624027 0.08426301 0.08154671]
+    """
+
+    # Compute the array of predicted positions in the predicted frames.
+    xx = pred_frames[0][0]
+    xy = pred_frames[0][1]
+    xz = pred_frames[0][2]
+    yx = pred_frames[0][3]
+    yy = pred_frames[0][4]
+    yz = pred_frames[0][5]
+    zx = pred_frames[0][6]
+    zy = pred_frames[0][7]
+    zz = pred_frames[0][8]
+    t0_p = pred_frames[1][0]
+    t1_p = pred_frames[1][1]
+    t2_p = pred_frames[1][2]
+    t0 = pred_positions[0]
+    t1 = pred_positions[1]
+    t2 = pred_positions[2]
+
+    v1 = -(xx * t0_p + yx * t1_p + zx * t2_p)
+    v2 = -(xy * t0_p + yy * t1_p + zy * t2_p)
+    v3 = -(xz * t0_p + yz * t1_p + zz * t2_p)
+
+    local_pred_pos = [
+        xx[..., None] * t0[:, None, ...] + yx[..., None] * t1[:, None, ...] +
+        zx[..., None] * t2[:, None, ...] + v1[..., None],
+        xy[..., None] * t0[:, None, ...] + yy[..., None] * t1[:, None, ...] +
+        zy[..., None] * t2[:, None, ...] + v2[..., None],
+        xz[..., None] * t0[:, None, ...] + yz[..., None] * t1[:, None, ...] +
+        zz[..., None] * t2[:, None, ...] + v3[..., None]
+    ]
+    xx_gt = target_frames[0][0]
+    xy_gt = target_frames[0][1]
+    xz_gt = target_frames[0][2]
+    yx_gt = target_frames[0][3]
+    yy_gt = target_frames[0][4]
+    yz_gt = target_frames[0][5]
+    zx_gt = target_frames[0][6]
+    zy_gt = target_frames[0][7]
+    zz_gt = target_frames[0][8]
+    t0_t = target_frames[1][0]
+    t1_t = target_frames[1][1]
+    t2_t = target_frames[1][2]
+    t0_gt = target_positions[0]
+    t1_gt = target_positions[1]
+    t2_gt = target_positions[2]
+
+    v1_gt = -(xx_gt * t0_t + yx_gt * t1_t + zx_gt * t2_t)
+    v2_gt = -(xy_gt * t0_t + yy_gt * t1_t + zy_gt * t2_t)
+    v3_gt = -(xz_gt * t0_t + yz_gt * t1_t + zz_gt * t2_t)
+
+    epsilon = 1e-4
+
+    local_target_pos = [
+        xx_gt[:, None] * t0_gt[None, :] + yx_gt[:, None] * t1_gt[None, :] +
+        zx_gt[:, None] * t2_gt[None, :] + v1_gt[:, None],
+        xy_gt[:, None] * t0_gt[None, :] + yy_gt[:, None] * t1_gt[None, :] +
+        zy_gt[:, None] * t2_gt[None, :] + v2_gt[:, None],
+        xz_gt[:, None] * t0_gt[None, :] + yz_gt[:, None] * t1_gt[None, :] +
+        zz_gt[:, None] * t2_gt[None, :] + v3_gt[:, None]
+    ]
+    error_dist = mnp.sqrt(ops.Square()(local_pred_pos[0] - local_target_pos[0]) +
+                          ops.Square()(local_pred_pos[1] - local_target_pos[1]) +
+                          ops.Square()(local_pred_pos[2] - local_target_pos[2]) + epsilon)
+
+    all_mask = ops.expand_dims(frames_mask, axis=-1) * ops.expand_dims(positions_mask, axis=-2)
+    all_mask = all_mask * pair_mask
+    normalization_factor = mnp.sum(all_mask)
+
+    # FAPE with clamp.
+    error_dist_clamp = mnp.clip(error_dist, 0, l1_clamp_distance)
+    normed_error_clamp = error_dist_clamp / length_scale
+    error_clamp = P.ReduceSum()(normed_error_clamp * all_mask, (-2, -1)) / (epsilon + normalization_factor)
+
+    # FAPE without clamp.
+    normed_error_no_clamp = error_dist / length_scale
+    error_no_clamp = P.ReduceSum()(normed_error_no_clamp * all_mask, (-2, -1)) / (epsilon + normalization_factor)
+
+    # Clamped FAPE restricted to restrained (sbr) residue pairs.
+    sbr_mask = all_mask * sbr_mask
+    sbr_fape_clamp = P.ReduceSum()(normed_error_clamp * sbr_mask, (-2, -1)) / (epsilon + mnp.sum(sbr_mask))
+
+    return error_clamp, error_no_clamp, sbr_fape_clamp
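+
+# A quick numeric check of the clamping above (hypothetical values): with l1_clamp_distance = 10
+# and length_scale = 10, a point error of 25 contributes min(25, 10) / 10 = 1.0, so every
+# pair's contribution saturates at 1 after clamping.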
+
+
+def backbone(traj, backbone_affine_tensor, backbone_affine_mask, fape_clamp_distance, fape_loss_unit_distance,
+             use_clamped_fape, asym_id, sbr_mask):
+    r"""
+    Backbone FAPE loss using the `frame_aligned_point_error_map` function.
+    `Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17
+    `_.
+
+    Args:
+        traj (Tensor): The series of backbone frames (trajectory) generated by the Structure
+            module, of shape :math:`(N_{recycle}, N_{res}, 7)`, with :math:`N_{recycle}` the
+            number of recycles in the Structure module and :math:`N_{res}` the number of residues
+            in the protein. In the last dimension, the first 4 elements are the affine tensor which
+            contains the rotation information and the last 3 elements are the translations in space.
+        backbone_affine_tensor (Tensor): The ground truth backbone frames of shape :math:`(N_{res}, 7)`.
+        backbone_affine_mask (Tensor): The binary mask for backbone frames of shape :math:`(N_{res},)`.
+        fape_clamp_distance (float): Distance cutoff on error beyond which gradients will
+            be zero.
+        fape_loss_unit_distance (float): The unit distance of the backbone FAPE loss, used to scale
+            distances.
+        use_clamped_fape (float): Indicator of whether the backbone FAPE loss is clamped,
+            0 or 1, where 1 means clamping.
+        asym_id (Tensor): The chain id of each residue, of shape :math:`(N_{res},)`; used to split
+            the loss into intra-chain and inter-chain (interface) terms.
+        sbr_mask (Tensor): Binary mask of restrained residue pairs, shape :math:`(N_{res}, N_{res})`.
+
+    Returns:
+        - **fape** (Tensor) - Backbone FAPE loss (clamped if use_clamped_fape is 1) of the last recycle
+          of the Structure module, with shape ().
+        - **loss** (Tensor) - Averaged backbone FAPE loss (clamped if use_clamped_fape is 1) over all
+          recycles of the Structure module, with shape ().
+        - **no_clamp** (Tensor) - Backbone FAPE loss of the last recycle of the Structure module, with shape ().
+        - **fape_no_clamp_intra** (Tensor) - Unclamped intra-chain backbone FAPE of the last recycle, shape ().
+        - **fape_no_clamp_interface** (Tensor) - Unclamped inter-chain backbone FAPE of the last recycle, shape ().
+        - **sbr_fape_intra** (Tensor) - Clamped intra-chain FAPE over restrained pairs of the last recycle,
+          shape ().
+        - **sbr_fape_interface** (Tensor) - Clamped inter-chain FAPE over restrained pairs of the last recycle,
+          shape ().
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> np.random.seed(0)
+        >>> from mindsponge.metrics import backbone
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> traj = Tensor(np.random.rand(8, 256, 7)).astype(mstype.float32)
+        >>> backbone_affine_tensor = Tensor(np.random.rand(256, 7)).astype(mstype.float32)
+        >>> backbone_affine_mask = Tensor(np.random.rand(256,)).astype(mstype.float16)
+        >>> fape_clamp_distance = 10.0
+        >>> fape_loss_unit_distance = 10.0
+        >>> use_clamped_fape = 1
+        >>> asym_id = Tensor(np.zeros((256,))).astype(mstype.float32)
+        >>> sbr_mask = Tensor(np.ones((256, 256))).astype(mstype.float32)
+        >>> fape, loss, noclamp, intra, interface, sbr_intra, sbr_interface = backbone(traj,
+        ...     backbone_affine_tensor, backbone_affine_mask, fape_clamp_distance,
+        ...     fape_loss_unit_distance, use_clamped_fape, asym_id, sbr_mask)
+        >>> print(fape, loss, noclamp)
+        0.12813742 0.12904957 0.12813742
+    """
+
+    _, rotation, translation = quaternion_from_tensor(traj)
+    pred_frames = ((rotation[0], rotation[1], rotation[2],
+                    rotation[3], rotation[4], rotation[5],
+                    rotation[6], rotation[7], rotation[8]),
+                   (translation[0], translation[1], translation[2]))
+    pred_positions = [translation[0], translation[1], translation[2]]
+
+    _, rotation_gt, translation_gt = quaternion_from_tensor(backbone_affine_tensor)
+    target_frames = ((rotation_gt[0], rotation_gt[1], rotation_gt[2],
+                      rotation_gt[3], rotation_gt[4], rotation_gt[5],
+                      rotation_gt[6], rotation_gt[7], rotation_gt[8]),
+                     (translation_gt[0], translation_gt[1], translation_gt[2]))
+    target_positions = [translation_gt[0], translation_gt[1], translation_gt[2]]
+
+    frames_mask = backbone_affine_mask
+    positions_mask = backbone_affine_mask
+
+    intra_chain_mask = P.Cast()(asym_id[:, None] == asym_id[None, :], ms.float32)
+
+    fape_loss_clamp_intra, fape_loss_no_clamp_intra, sbr_fape_clamp_intra = \
+        frame_aligned_point_error_map(pred_frames, target_frames, frames_mask, pred_positions,
+                                      target_positions, positions_mask, fape_clamp_distance,
+                                      fape_loss_unit_distance, intra_chain_mask, sbr_mask)
+
+    # Inter-chain (interface) pairs use larger scale and clamp parameters.
+    fape_loss_clamp_interface, fape_loss_no_clamp_interface, sbr_fape_clamp_interface = \
+        frame_aligned_point_error_map(pred_frames, target_frames, frames_mask, pred_positions,
+                                      target_positions, positions_mask, 20.0, 30.0,
+                                      1 - intra_chain_mask, sbr_mask)
+
+    fape_loss_clamp = fape_loss_clamp_interface + fape_loss_clamp_intra
+    fape_loss_no_clamp = fape_loss_no_clamp_interface + fape_loss_no_clamp_intra
+
+    fape_loss = (fape_loss_clamp * use_clamped_fape + fape_loss_no_clamp * (1 - use_clamped_fape))
+    no_clamp = fape_loss_no_clamp[-1]
+    fape = fape_loss[-1]
+    loss = mnp.mean(fape_loss)
+    return fape, loss, no_clamp, fape_loss_no_clamp_intra[-1], fape_loss_no_clamp_interface[-1], \
+        sbr_fape_clamp_intra[-1], sbr_fape_clamp_interface[-1]
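+
+# A minimal sketch of the intra-chain mask used above, for a hypothetical two-chain complex
+# with chain ids [0, 0, 1]:
+#   asym_id = Tensor([0., 0., 1.], ms.float32)
+#   intra = P.Cast()(asym_id[:, None] == asym_id[None, :], ms.float32)
+#   # intra == [[1, 1, 0], [1, 1, 0], [0, 0, 1]]; (1 - intra) selects the interface pairs.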
+
+
+def frame_aligned_point_error(pred_frames,
+                              target_frames,
+                              frames_mask,
+                              pred_positions,
+                              target_positions,
+                              positions_mask,
+                              length_scale,
+                              l1_clamp_distance):
+    r"""
+    Measure point error under different alignments: this computes the error between two
+    structures with B points under A alignments derived from the given pairs of frames.
+    `Jumper et al. (2021) Suppl. Alg. 28 "computeFAPE"
+    `_.
+    This function considers all frames.
+
+    First transform the predicted atom positions to the different predicted local frames,
+    :math:`\vec{x_{j\_pred}^{i}} = \mathcal{T}_{i\_{pred}} \circ \vec{x_{j\_pred}}`.
+    Then transform the true atom positions to the different true local frames,
+    :math:`\vec{x_{j\_gt}^{i}} = \mathcal{T}_{i\_{gt}} \circ \vec{x_{j\_gt}}`.
+    Then compute the L2 error of all atom positions in all local frames,
+    :math:`\sum_{i}^{N_{frames}}\sum_{j}^{N_{atoms}} \parallel \vec{x_{j\_pred}^{i}} -
+    \vec{x_{j\_gt}^{i}} \parallel`.
+
+    Args:
+        pred_frames (Tensor): The predicted frames of shape :math:`(12, N_{frames})` with
+            :math:`N_{frames}` the number of pairs of frames. In the first dimension, the first
+            9 elements are the 9 components of the rotation matrix and the last 3 elements are
+            the 3 components of the translation.
+        target_frames (Tensor): The ground truth frames of the same shape as pred_frames.
+        frames_mask (Tensor): The binary mask for frames of shape :math:`(N_{frames},)`.
+        pred_positions (Tensor): The predicted atom positions tensor of shape
+            :math:`(3, N_{atoms})` with :math:`N_{atoms}` the number of atoms.
+        target_positions (Tensor): The ground truth atom positions of the same shape as
+            pred_positions.
+        positions_mask (Tensor): The binary mask for atom positions of shape :math:`(N_{atoms},)`.
+        length_scale (float): The unit distance which is used to scale distances.
+        l1_clamp_distance (float): Distance cutoff on error beyond which gradients will
+            be zero.
+
+    Returns:
+        - **error** (Tensor) - FAPE loss (clamped if l1_clamp_distance is given); for the flat
+          (unbatched) inputs used here the result is a scalar with shape ().
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> np.random.seed(0)
+        >>> from mindsponge.metrics import frame_aligned_point_error
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> pred_frames = Tensor(np.random.rand(12, 256)).astype(mstype.float32)
+        >>> target_frames = Tensor(np.random.rand(12, 256)).astype(mstype.float32)
+        >>> frames_mask = Tensor(np.random.rand(256,)).astype(mstype.float32)
+        >>> pred_positions = Tensor(np.random.rand(3, 1024)).astype(mstype.float32)
+        >>> target_positions = Tensor(np.random.rand(3, 1024)).astype(mstype.float32)
+        >>> positions_mask = Tensor(np.random.rand(1024,)).astype(mstype.float32)
+        >>> length_scale = 10.0
+        >>> l1_clamp_distance = 10.0
+        >>> fape = frame_aligned_point_error(pred_frames, target_frames, frames_mask,
+        ...                                  pred_positions, target_positions, positions_mask,
+        ...                                  length_scale, l1_clamp_distance)
+        >>> print(fape)
+        0.08747593
+    """
+
+    # Compute the array of predicted positions in the predicted frames.
+    xx = pred_frames[0]
+    xy = pred_frames[1]
+    xz = pred_frames[2]
+    yx = pred_frames[3]
+    yy = pred_frames[4]
+    yz = pred_frames[5]
+    zx = pred_frames[6]
+    zy = pred_frames[7]
+    zz = pred_frames[8]
+    t0_p = pred_frames[9]
+    t1_p = pred_frames[10]
+    t2_p = pred_frames[11]
+    t0 = pred_positions[0]
+    t1 = pred_positions[1]
+    t2 = pred_positions[2]
+
+    v1 = -(xx * t0_p + yx * t1_p + zx * t2_p)
+    v2 = -(xy * t0_p + yy * t1_p + zy * t2_p)
+    v3 = -(xz * t0_p + yz * t1_p + zz * t2_p)
+
+    local_pred_pos = [
+        xx[..., None] * t0[None, ...] + yx[..., None] * t1[None, ...] + zx[..., None] * t2[None, ...] + v1[..., None],
+        xy[..., None] * t0[None, ...] + yy[..., None] * t1[None, ...] + zy[..., None] * t2[None, ...] + v2[..., None],
+        xz[..., None] * t0[None, ...] + yz[..., None] * t1[None, ...] + zz[..., None] * t2[None, ...] + v3[..., None]
+    ]
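+
+    # local_pred_pos now holds the predicted positions expressed in every predicted local
+    # frame: the inverse rigid transform R^T (x - t), computed component-wise above.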
+    xx_gt = target_frames[0]
+    xy_gt = target_frames[1]
+    xz_gt = target_frames[2]
+    yx_gt = target_frames[3]
+    yy_gt = target_frames[4]
+    yz_gt = target_frames[5]
+    zx_gt = target_frames[6]
+    zy_gt = target_frames[7]
+    zz_gt = target_frames[8]
+    t0_t = target_frames[9]
+    t1_t = target_frames[10]
+    t2_t = target_frames[11]
+    t0_gt = target_positions[0]
+    t1_gt = target_positions[1]
+    t2_gt = target_positions[2]
+
+    v1_gt = -(xx_gt * t0_t + yx_gt * t1_t + zx_gt * t2_t)
+    v2_gt = -(xy_gt * t0_t + yy_gt * t1_t + zy_gt * t2_t)
+    v3_gt = -(xz_gt * t0_t + yz_gt * t1_t + zz_gt * t2_t)
+
+    epsilon = 1e-4
+    local_target_pos = [
+        xx_gt[:, None] * t0_gt[None, :] + yx_gt[:, None] * t1_gt[None, :] +
+        zx_gt[:, None] * t2_gt[None, :] + v1_gt[:, None],
+        xy_gt[:, None] * t0_gt[None, :] + yy_gt[:, None] * t1_gt[None, :] +
+        zy_gt[:, None] * t2_gt[None, :] + v2_gt[:, None],
+        xz_gt[:, None] * t0_gt[None, :] + yz_gt[:, None] * t1_gt[None, :] +
+        zz_gt[:, None] * t2_gt[None, :] + v3_gt[:, None]
+    ]
+    error_dist = mnp.sqrt(ops.Square()(local_pred_pos[0] - local_target_pos[0]) +
+                          ops.Square()(local_pred_pos[1] - local_target_pos[1]) +
+                          ops.Square()(local_pred_pos[2] - local_target_pos[2]) + epsilon)
+    if l1_clamp_distance:
+        error_dist = mnp.clip(error_dist, 0, l1_clamp_distance)
+
+    normed_error = error_dist / length_scale
+    normed_error *= ops.expand_dims(frames_mask, axis=-1)
+    normed_error *= ops.expand_dims(positions_mask, axis=-2)
+
+    normalization_factor = mnp.sum(frames_mask, axis=-1) * mnp.sum(positions_mask, axis=-1)
+    return mnp.sum(normed_error, axis=(-2, -1)) / (epsilon + normalization_factor)
+
+
+def sidechain(alt_naming_is_better, rigidgroups_gt_frames, rigidgroups_alt_gt_frames, rigidgroups_gt_exists,
+              renamed_atom14_gt_positions, renamed_atom14_gt_exists, sidechain_atom_clamp_distance,
+              sidechain_length_scale, pred_frames, pred_positions):
+    r"""
+    Sidechain FAPE loss, which takes all local frames (side-chain and backbone) into consideration.
+    `Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 28
+    `_.
+
+    Args:
+        alt_naming_is_better (Tensor): Tensor of shape :math:`(N_{res},)` with value 1.0 where the
+            alternative swap is better.
+        rigidgroups_gt_frames (Tensor): The ground truth local frames of shape :math:`(N_{res}, 8, 12)`,
+            with :math:`N_{res}` the number of residues in the protein. For each residue there are 1 backbone
+            frame and 7 side-chain frames, 8 frames in total. In the last dimension, the first 9 elements
+            are the 9 components of the rotation matrix and the last 3 elements are the 3 components of the
+            translation.
+        rigidgroups_alt_gt_frames (Tensor): The alternative ground truth local frames due to the
+            symmetry of amino acids. This tensor has the same shape as rigidgroups_gt_frames.
+        rigidgroups_gt_exists (Tensor): The binary mask for gt frames of shape :math:`(N_{res}, 8)`.
+        renamed_atom14_gt_positions (Tensor): The ground truth positions after renaming swaps are
+            performed (swaps are needed for some amino acids due to symmetry, see
+            `compute_renamed_ground_truth`), of shape :math:`(N_{res}, 14, 3)` in the atom14 encoding.
+        renamed_atom14_gt_exists (Tensor): The mask for ground truth positions after renaming swaps are
+            performed, of shape :math:`(N_{res}, 14)`.
+        sidechain_atom_clamp_distance (float): Distance cutoff on error beyond which gradients
+            will be zero.
+        sidechain_length_scale (float): The unit distance of the sidechain FAPE loss, used to scale
+            distances.
+        pred_frames (Tensor): The predicted local frames of shape :math:`(12, N_{recycle}, N_{res}, 8)`,
+            with :math:`N_{recycle}` the number of recycles of FoldIteration in the Structure module. Only the
+            frames of the last recycle are used in the side-chain FAPE loss. 12 has the same meaning as the
+            last dimension of rigidgroups_gt_frames.
+        pred_positions (Tensor): The predicted positions of shape :math:`(3, N_{recycle}, N_{res}, 14)`,
+            in the atom14 encoding. Only the positions of the last recycle are used in the side-chain FAPE loss.
+
+    Returns:
+        Tensor, fape. Clamped side-chain FAPE loss with shape ().
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> np.random.seed(0)
+        >>> from mindsponge.metrics import sidechain
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> alt_naming_is_better = Tensor(np.zeros((256,))).astype(mstype.float32)
+        >>> rigidgroups_gt_frames = Tensor(np.random.rand(256, 8, 12)).astype(mstype.float32)
+        >>> rigidgroups_alt_gt_frames = Tensor(np.random.rand(256, 8, 12)).astype(mstype.float32)
+        >>> rigidgroups_gt_exists = Tensor(np.random.rand(256, 8)).astype(mstype.float32)
+        >>> renamed_atom14_gt_positions = Tensor(np.random.rand(256, 14, 3)).astype(mstype.float32)
+        >>> renamed_atom14_gt_exists = Tensor(np.random.rand(256, 14)).astype(mstype.float32)
+        >>> sidechain_atom_clamp_distance = 10.0
+        >>> sidechain_length_scale = 10.0
+        >>> pred_frames = Tensor(np.random.rand(12, 8, 256, 8)).astype(mstype.float32)
+        >>> pred_positions = Tensor(np.random.rand(3, 8, 256, 14)).astype(mstype.float32)
+        >>> sidechain_loss = sidechain(alt_naming_is_better, rigidgroups_gt_frames, rigidgroups_alt_gt_frames,
+        ...                            rigidgroups_gt_exists, renamed_atom14_gt_positions,
+        ...                            renamed_atom14_gt_exists, sidechain_atom_clamp_distance,
+        ...                            sidechain_length_scale, pred_frames, pred_positions)
+        >>> print(sidechain_loss)
+        0.08569459
+    """
+    # Rename frames.
+    # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7
+    renamed_gt_frames = ((1. - alt_naming_is_better[:, None, None]) * rigidgroups_gt_frames
+                         + alt_naming_is_better[:, None, None] * rigidgroups_alt_gt_frames)
+    flat_gt_frames = mnp.moveaxis(mnp.reshape(renamed_gt_frames, [-1, 12]), -1, 0)
+    flat_frames_mask = mnp.reshape(rigidgroups_gt_exists, [-1])
+
+    flat_gt_positions_t = mnp.moveaxis(mnp.reshape(renamed_atom14_gt_positions, [-1, 3]), -1, 0)
+    flat_positions_mask = mnp.reshape(renamed_atom14_gt_exists, [-1])
+
+    # Compute the frame_aligned_point_error score for the final layer.
+    flat_pred_frames = mnp.reshape(pred_frames[:, -1, ...], [12, -1])
+    flat_pred_positions = mnp.reshape(pred_positions[:, -1, ...], [3, -1])
+
+    # FAPE loss on sidechains.
+    fape = frame_aligned_point_error(
+        pred_frames=flat_pred_frames,
+        target_frames=flat_gt_frames,
+        frames_mask=flat_frames_mask,
+        pred_positions=flat_pred_positions,
+        target_positions=flat_gt_positions_t,
+        positions_mask=flat_positions_mask,
+        l1_clamp_distance=sidechain_atom_clamp_distance,
+        length_scale=sidechain_length_scale)
+    return fape
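+
+# Shape bookkeeping for the flattening above: the (N_res, 8, 12) ground-truth frames become
+# (12, N_res * 8) and the (N_res, 14, 3) positions become (3, N_res * 14), so the FAPE call
+# aligns every atom against every rigid group exactly once.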
+
+
+def supervised_chi(sequence_mask, aatype, sin_cos_true_chi, torsion_angle_mask, sin_cos_pred_chi,
+                   sin_cos_unnormalized_pred, chi_weight, angle_norm_weight, chi_pi_periodic):
+    r"""Computes the loss for direct chi angle supervision. The torsion angles are represented by
+    the sine and cosine values of the angle. This loss is composed of two terms: the error of the
+    normalized predicted sine and cosine values, called the chi angle difference loss, and the
+    difference between the L2 norm of the sine and cosine values and 1, called the angle norm loss.
+    `Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss"
+    `_.
+
+    Args:
+        sequence_mask (Tensor): The mask tensor for the sequence of shape :math:`(N_{res},)`,
+            with :math:`N_{res}` the number of residues in the protein.
+        aatype (Tensor): The amino acid type tensor of shape :math:`(N_{res},)`.
+        sin_cos_true_chi (Tensor): Tensor of shape :math:`(N_{res}, 4, 2)`, the sine and cosine
+            values of the ground truth sidechain torsion angles. There are 7 torsion angles per
+            residue, 3 for the backbone and 4 for the sidechain; only the 4 sidechain angles are
+            supervised here.
+        torsion_angle_mask (Tensor): The binary mask for sidechain torsion angles of shape
+            :math:`(N_{res}, 4)`.
+        sin_cos_pred_chi (Tensor): The predicted sine and cosine values (normalized) of all
+            7 torsion angles, flattened to shape :math:`(N_{res}, 14)`.
+        sin_cos_unnormalized_pred (Tensor): The predicted sine and cosine values (unnormalized)
+            of torsion angles of shape :math:`(N_{recycle}, N_{res}, 7, 2)`, with :math:`N_{recycle}`
+            the number of recycles of FoldIteration in the Structure module.
+        chi_weight (float): The weight of the chi angle difference loss term, constant.
+        angle_norm_weight (float): The weight of the angle norm loss term, constant.
+        chi_pi_periodic (Tensor): Chi angles that are pi-periodic: they can be rotated
+            by a multiple of pi without affecting the structure. Constants of residues of shape
+            :math:`(21, 4)`, 20 types of amino acids + unknown.
+
+    Returns:
+        - **loss** (Tensor) - Supervised chi angle loss with shape :math:`()` .
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> np.random.seed(0)
+        >>> from mindsponge.metrics import supervised_chi
+        >>> from mindsponge.common import residue_constants
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> sequence_mask = Tensor(np.random.rand(256, )).astype(mstype.float32)
+        >>> aatype = Tensor(np.random.randint(0, 21, (256,))).astype(mstype.int32)
+        >>> sin_cos_true_chi = Tensor(np.random.rand(256, 4, 2)).astype(mstype.float32)
+        >>> torsion_angle_mask = Tensor(np.random.rand(256, 4)).astype(mstype.float32)
+        >>> sin_cos_pred_chi = Tensor(np.random.rand(256, 14)).astype(mstype.float32)
+        >>> sin_cos_unnormalized_pred = Tensor(np.random.rand(8, 256, 7, 2)).astype(mstype.float32)
+        >>> chi_weight = 0.1
+        >>> angle_norm_weight = 0.2
+        >>> chi_pi_periodic = Tensor(residue_constants.chi_pi_periodic).astype(mstype.float32)
+        >>> chi_loss = supervised_chi(sequence_mask, aatype, sin_cos_true_chi, torsion_angle_mask, sin_cos_pred_chi,
+        ...                           sin_cos_unnormalized_pred, chi_weight, angle_norm_weight, chi_pi_periodic)
+        >>> print(chi_loss)
+        0.061829045
+    """
+    eps = 1e-6
+
+    num_res = sequence_mask.shape[0]
+    chi_mask = torsion_angle_mask
+    pred_angles = mnp.reshape(sin_cos_pred_chi, [-1, num_res, 7, 2])
+    pred_angles = pred_angles[:, :, 3:]
+
+    residue_type_one_hot = nn.OneHot(depth=21)(aatype)[None]
+    chi_pi_periodic = mnp.matmul(residue_type_one_hot, chi_pi_periodic)
+
+    # This is -1 if chi is pi-periodic and +1 if it is 2pi-periodic.
+    shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
+    sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi
+
+    sq_chi_error = mnp.sum(mnp.square(sin_cos_true_chi - pred_angles), -1)
+    sq_chi_error_shifted = mnp.sum(mnp.square(sin_cos_true_chi_shifted - pred_angles), -1)
+    sq_chi_error = mnp.minimum(sq_chi_error, sq_chi_error_shifted)
+
+    sq_chi_loss = P.ReduceSum()(chi_mask[None] * sq_chi_error, (0, 1, 2)) / \
+        (P.ReduceSum()(chi_mask[None], (0, 1, 2)) * 8 + 1e-10)
+
+    loss = chi_weight * sq_chi_loss
+    unnormed_angles = mnp.reshape(sin_cos_unnormalized_pred[-1], [-1, num_res, 7, 2])
+    angle_norm = mnp.sqrt(mnp.sum(mnp.square(unnormed_angles), axis=-1) + eps)
+    norm_error = mnp.abs(angle_norm - 1.)
+    angle_norm_loss = P.ReduceSum()(sequence_mask[None, :, None] * norm_error, (0, 1, 2)) / \
+        (P.ReduceSum()(sequence_mask[None, :, None].astype(ms.float32), (0, 1, 2)) * 7 + 1e-10)
+
+    loss += angle_norm_weight * angle_norm_loss
+    return loss
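+
+# Why the shifted target handles pi-periodic chi angles: multiplying (sin x, cos x) by -1
+# yields (sin(x + pi), cos(x + pi)), so taking the elementwise minimum of the two squared
+# errors leaves a prediction of chi + pi unpenalized for pi-periodic angles.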
+
+
+def local_distance_difference_test(predicted_points, true_points, true_points_mask, cutoff=15, per_residue=False):
+    r"""
+    Compute true and predicted distance matrices for :math:`C\alpha` atoms and score their agreement.
+    First calculate the distance matrices of the true and predicted :math:`C\alpha` atoms,
+    :math:`D = (((x[None, :] - x[:, None])^2).sum(-1))^{0.5}`,
+    then compute the rate at which the difference is smaller than a set of fixed thresholds:
+    :math:`lddt = (rate(abs(D_{true} - D_{pred}) < 0.5) + rate(abs(D_{true} - D_{pred}) < 1.0) +
+    rate(abs(D_{true} - D_{pred}) < 2.0) + rate(abs(D_{true} - D_{pred}) < 4.0))/4`.
+    `Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca"
+    `_.
+
+    Args:
+        predicted_points (Tensor): The predicted Ca atom position tensor of shape
+            :math:`(1, N_{res}, 3)`, with :math:`N_{res}` the number of residues in the protein.
+        true_points (Tensor): The ground truth Ca atom position tensor of shape
+            :math:`(1, N_{res}, 3)`.
+        true_points_mask (Tensor): The binary mask for true_points of shape
+            :math:`(1, N_{res}, 1)`.
+        cutoff (float): Maximum distance in the true structure for a pair to be scored. Default: 15.
+        per_residue (bool): Indicator of whether the local distance difference is averaged;
+            set True to return the local distance difference per residue. Default: False.
+
+    Returns:
+        - **score** (Tensor) - Local distance difference score, of shape :math:`(1,)`
+          if per_residue is False, and :math:`(1, N_{res})` otherwise.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> np.random.seed(0)
+        >>> from mindsponge.metrics import local_distance_difference_test
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> predicted_points = Tensor(np.random.rand(1, 256, 3)).astype(mstype.float32)
+        >>> true_points = Tensor(np.random.rand(1, 256, 3)).astype(mstype.float32)
+        >>> true_points_mask = Tensor(np.random.rand(1, 256, 1)).astype(mstype.float32)
+        >>> lddt = local_distance_difference_test(predicted_points, true_points, true_points_mask,
+        ...                                       cutoff=15, per_residue=False)
+        >>> print(lddt)
+        [0.9554313]
+    """
+    dmat_true = mnp.sqrt(1e-10 + mnp.sum((true_points[:, :, None] - true_points[:, None, :]) ** 2, axis=-1))
+
+    dmat_predicted = mnp.sqrt(1e-10 + mnp.sum((predicted_points[:, :, None] - predicted_points[:, None, :]) ** 2,
+                                              axis=-1))
+
+    dists_to_score = ((dmat_true < cutoff).astype(mnp.float32) * true_points_mask *
+                      mnp.transpose(true_points_mask, [0, 2, 1]) *
+                      (1. - mnp.eye(dmat_true.shape[1]))  # Exclude self-interaction.
+                      )
+
+    # Shift unscored distances to be far away.
+    dist_l1 = mnp.abs(dmat_true - dmat_predicted)
+
+    # True lDDT uses a number of fixed bins.
+    # We ignore the physical plausibility correction to lDDT, though.
+    score = 0.25 * ((dist_l1 < 0.5).astype(mnp.float32) +
+                    (dist_l1 < 1.0).astype(mnp.float32) +
+                    (dist_l1 < 2.0).astype(mnp.float32) +
+                    (dist_l1 < 4.0).astype(mnp.float32))
+
+    # Normalize over the appropriate axes.
+    reduce_axes = (-1,) if per_residue else (-2, -1)
+    norm = 1. / (1e-10 + mnp.sum(dists_to_score, axis=reduce_axes))
+    score = norm * (1e-10 + mnp.sum(dists_to_score * score, axis=reduce_axes))
+    return score
+
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/ops/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d089c1b045e6382e300616da5ecf5b98206ac3a
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ops"""
+
+from . import cpu
+from .cpu import *
+
+__all__ = []
+__all__.extend(cpu.__all__)
+__all__.sort()
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/ops/cpu/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/cpu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6ab5aec237aace1c143dd63f0f0c50a672ca599
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/cpu/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""NeighborListOP"""
+
+from .neighborlistop import NeighborListOP
+__all__ = ['NeighborListOP']
+
+__all__.sort()
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/ops/cpu/neighborlistop.py b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/cpu/neighborlistop.py
new file mode 100644
index 0000000000000000000000000000000000000000..802a9c529e4a15ad60df2ec699ef7d805d561b25
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/cpu/neighborlistop.py
@@ -0,0 +1,84 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""neighborlistop"""
+
+import os
+import mindspore.common.dtype as mstype
+import mindspore.ops as ops
+from mindspore.ops import DataType, CustomRegOp
+
+put_atom_into_bucket_add = CustomRegOp() \
+    .input(0, "x0") \
+    .input(1, "x1") \
+    .input(2, "x2") \
+    .output(0, "y0") \
+    .output(1, "y1") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, \
+                  DataType.None_None, DataType.None_None) \
+    .target("CPU") \
+    .get_op_info()
+
+find_atom_neighbors_add = CustomRegOp() \
+    .input(0, "x0") \
+    .input(1, "x1") \
+    .input(2, "x2") \
+    .input(3, "x3") \
+    .input(4, "x4") \
+    .input(5, "x5") \
+    .input(6, "x6") \
+    .input(7, "x7") \
+    .input(8, "x8") \
+    .output(0, "y0") \
+    .output(1, "y1") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, \
+                  DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, \
+                  DataType.None_None, DataType.None_None, DataType.None_None) \
+    .target("CPU") \
+    .get_op_info()
+
+# The two outputs are indexed from 0, like the inputs.
+delete_excluded_atoms_add = CustomRegOp() \
+    .input(0, "x0") \
+    .input(1, "x1") \
+    .input(2, "x2") \
+    .input(3, "x3") \
+    .input(4, "x4") \
+    .output(0, "y0") \
+    .output(1, "y1") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, \
+                  DataType.None_None, DataType.None_None, DataType.None_None, \
+                  DataType.None_None) \
+    .target("CPU") \
+    .get_op_info()
+
+
+class NeighborListOP():
+    """NeighborListOP"""
+    def __init__(self):
+        lib_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../libs/libneighborlist.so"))
+        self.put_atom_path = lib_path + ":PutAtomIntoBucket"
+        self.find_atom_path = lib_path + ":FindAtomNeighbors"
+        self.delete_atom_path = lib_path + ":DeleteExcludedAtoms"
+
+    def register(self, atom_numbers, grid_numbers, max_atom_in_grid_numbers, max_neighbor_numbers):
+        """Register the neighbor list operators."""
+        put_atom_into_bucket_op = ops.Custom(self.put_atom_path, \
+            out_shape=(([grid_numbers, max_atom_in_grid_numbers], [grid_numbers,])), \
+            out_dtype=(mstype.int32, mstype.int32), func_type="aot", reg_info=put_atom_into_bucket_add)
+        find_atom_neighbors_op = ops.Custom(self.find_atom_path, \
+            out_shape=(([atom_numbers,], [atom_numbers, max_neighbor_numbers])), \
+            out_dtype=(mstype.int32,
mstype.int32), func_type="aot", reg_info=find_atom_neighbors_add) + delete_excluded_atoms_op = ops.Custom(self.delete_atom_path, \ + out_shape=(([atom_numbers,], [atom_numbers, max_neighbor_numbers])), \ + out_dtype=(mstype.int32, mstype.int32), func_type="aot", reg_info=delete_excluded_atoms_add) + return put_atom_into_bucket_op, find_atom_neighbors_op, delete_excluded_atoms_op diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/ops/gpu/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/gpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ab5aec237aace1c143dd63f0f0c50a672ca599 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/gpu/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""NeighborListOP""" + +from .neighborlistop import NeighborListOP +__all__ = ['NeighborListOP'] + +__all__.sort() diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/ops/gpu/neighborlistop.py b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/gpu/neighborlistop.py new file mode 100644 index 0000000000000000000000000000000000000000..a2dda8023e1110acbfc3b38839199806ce97ca55 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/ops/gpu/neighborlistop.py @@ -0,0 +1,84 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""neighborlistop""" + +import os +import mindspore.common.dtype as mstype +import mindspore.ops as ops +from mindspore.ops import DataType, CustomRegOp + +put_atom_into_bucket_add = CustomRegOp() \ + .input(0, "x0") \ + .input(1, "x1") \ + .input(2, "x2") \ + .output(0, "y0") \ + .output(1, "y1") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I32_Default) \ + .target("GPU") \ + .get_op_info() + +find_atom_neighbors_add = CustomRegOp() \ + .input(0, "x0") \ + .input(1, "x1") \ + .input(2, "x2") \ + .input(3, "x3") \ + .input(4, "x4") \ + .input(5, "x5") \ + .input(6, "x6") \ + .input(7, "x7") \ + .input(8, "x8") \ + .output(0, "y0") \ + .output(1, "y1") \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, \ + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .target("GPU") \ + .get_op_info() + +delete_excluded_atoms_add = CustomRegOp() \ + .input(0, "x0") \ + .input(1, "x1") \ + .input(2, "x2") \ + .input(3, "x3") \ + .input(4, "x4") \ + .output(0, "y0") \ + .output(1, "y1") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default) \ + .target("GPU") \ + .get_op_info() + +class NeighborListOP(): + """NeighborListOP""" + def __init__(self): + lib_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../libs/libneighborlist.so")) + self.put_atom_path = lib_path + ":PutAtomIntoBucket" + self.find_atom_path = lib_path + ":FindAtomNeighbors" + self.delete_atom_path = lib_path + ":DeleteExcludedAtoms" + + def register(self, atom_numbers, grid_numbers, max_atom_in_grid_numbers, max_neighbor_numbers): + """Register the neighbor list operator.""" + put_atom_into_bucket_op = ops.Custom(self.put_atom_path, \ + out_shape=(([grid_numbers, max_atom_in_grid_numbers], [grid_numbers,])), \ + out_dtype=(mstype.int32, mstype.int32), func_type="aot", reg_info=put_atom_into_bucket_add) + find_atom_neighbors_op = ops.Custom(self.find_atom_path, \ + out_shape=(([atom_numbers,], [atom_numbers, max_neighbor_numbers])), \ + out_dtype=(mstype.int32, mstype.int32), func_type="aot", reg_info=find_atom_neighbors_add) + delete_excluded_atoms_op = ops.Custom(self.delete_atom_path, \ + out_shape=(([atom_numbers,], [atom_numbers, max_neighbor_numbers])), \ + out_dtype=(mstype.int32, mstype.int32), func_type="aot", reg_info=delete_excluded_atoms_add) + return put_atom_into_bucket_op, find_atom_neighbors_op, delete_excluded_atoms_op diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce22e161faeac24587cea722e31beb885982c344 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. 
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Space updater"""
+
+from .updater import Updater
+from .dynamics import DynamicUpdater
+from .steepest import SteepestDescent
+
+__all__ = ['Updater', 'DynamicUpdater', 'SteepestDescent']
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/dynamics.py b/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/dynamics.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cbf46de5805395fed3fd7b3ef162c6d66b60a78
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/dynamics.py
@@ -0,0 +1,141 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+#                       Peking University &
+#                       Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Space updater"""
+
+from mindspore import Tensor
+from mindspore.nn.optim.optimizer import opt_init_args_register
+
+from . import Updater
+from ..system import Molecule
+from ..control.controller import Controller
+from ..control.integrator import Integrator
+
+
+class DynamicUpdater(Updater):
+    r"""
+    An updater for molecular dynamics (MD) simulation.
+
+    Args:
+        system (Molecule): Simulation system.
+        integrator (Integrator): MD integrator.
+        thermostat (Controller): Thermostat for temperature coupling. Default: None
+        barostat (Controller): Barostat for pressure coupling. Default: None
+        constraint (Controller): Constraint controller for bonds. Default: None
+        controller (Controller): Other controllers. Default: None
+        time_step (float): Time step. Default: 1e-3
+        velocity (Tensor): Tensor of shape (B, A, D). Data type is float.
+            Default: None
+        weight_decay (float): A value for the weight decay. Default: 0.0
+        loss_scale (float): A value for the loss scale. Default: 1.0
+
+    Returns:
+        bool, whether the parameters of the system were successfully updated.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Symbols:
+        B: Batchsize, i.e. number of walkers in the simulation.
+        A: Number of atoms.
+        D: Dimension of the simulation system. Usually is 3.
+
+    """
+ + """ + @opt_init_args_register + def __init__(self, + system: Molecule, + integrator: Integrator, + thermostat: Controller = None, + barostat: Controller = None, + constraint: Controller = None, + controller: Controller = None, + time_step: float = 1e-3, + velocity: Tensor = None, + weight_decay: float = 0.0, + loss_scale: float = 1.0, + ): + + super().__init__( + system=system, + controller=controller, + time_step=time_step, + velocity=velocity, + weight_decay=weight_decay, + loss_scale=loss_scale, + ) + + self.integrator = integrator + + if thermostat is not None: + self.integrator.set_thermostat(thermostat) + self.thermostat = self.integrator.thermostat + + if barostat is not None: + if self.pbc_box is None: + raise ValueError('Barostat cannot be used for the system without periodic boundary condition.') + self.integrator.set_barostat(barostat) + self.barostat = self.integrator.barostat + + if constraint is not None: + self.integrator.set_constraint(constraint) + self.constraint = self.integrator.constraint + + self.integrator.set_time_step(self.time_step) + self.integrator.set_degrees_of_freedom(self.degrees_of_freedom) + + def construct(self, gradients: tuple, loss: Tensor = None): + """update the parameters of system""" + gradients = self.decay_weight(gradients) + gradients = self.scale_grad(gradients) + + coordinate = self.coordinate + velocity = self.velocity + force = -gradients[0] + energy = loss + kinetics = self.kinetics + pbc_box = self.pbc_box + virial = None + if self.pbc_box is not None: + virial = self.get_virial(gradients[1], pbc_box) + + step = self.identity(self.step) + coordinate, velocity, force, energy, kinetics, virial, pbc_box = \ + self.integrator(coordinate, velocity, force, energy, kinetics, virial, pbc_box, step) + + if self.controller is not None: + for i in range(self.num_controller): + coordinate, velocity, force, energy, kinetics, virial, pbc_box = \ + self.controller[i](coordinate, velocity, force, energy, kinetics, virial, pbc_box, step) + + temperature = self.get_temperature(kinetics) + pressure = self.get_pressure(kinetics, virial, pbc_box) + + success = True + success = self.update_coordinate(coordinate, success) + success = self.update_velocity(velocity, success) + success = self.update_pbc_box(pbc_box, success) + success = self.update_kinetics(kinetics, success) + success = self.update_temperature(temperature, success) + success = self.update_virial(virial, success) + success = self.update_pressure(pressure, success) + + return self.next_step(success) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/steepest.py b/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/steepest.py new file mode 100644 index 0000000000000000000000000000000000000000..173e117713e940c0f8f7f08e6a7ede8bc621c7c2 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/steepest.py @@ -0,0 +1,44 @@ +""" +Optimizer used to get the minimum value of a given function. +""" +import mindspore as ms +from mindspore import nn, Parameter, Tensor +from mindspore import numpy as msnp + + +class SteepestDescent(nn.Optimizer): + """ + The steepest descent (gradient descent) optimizer with growing learning rate. + + Args: + crd(tuple): Usually a tuple of parameters is given and the first element is coordinates. + learning_rate(float): A factor of each optimize step size. + factor(float): A growing factor of learning rate. + nonh_mask(Tensor): The mask of atoms which are not Hydrogen. 
+ max_shift(float): The maximum distance each atom is allowed to move within one step. + + Returns: + Tensor, the updated coordinates (the first element of the parameters). + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, crd, learning_rate=1e-03, factor=1.001, nonh_mask=None, max_shift=1.0): + super(SteepestDescent, self).__init__(learning_rate, crd) + self.crd = crd[0] + self.learning_rate = Parameter(Tensor(learning_rate, ms.float32)) + self.factor = Parameter(Tensor(factor, ms.float32)) + if nonh_mask is not None: + self.nonh_mask = nonh_mask + else: + self.nonh_mask = msnp.ones((1, self.crd.shape[-2], 1)) + self.max_shift = Parameter(Tensor(max_shift, ms.float32)) + + def construct(self, gradients): + shift = self.learning_rate*gradients[0]*self.nonh_mask + shift = msnp.where(shift > self.max_shift, self.max_shift, shift) + shift = msnp.where(shift < -self.max_shift, -self.max_shift, shift) + self.crd -= shift + self.learning_rate *= self.factor + return self.crd diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/updater.py b/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/updater.py new file mode 100644 index 0000000000000000000000000000000000000000..7a33b6ae3b41f6070a6f1432d899421354ae400f --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/optimizer/updater.py @@ -0,0 +1,407 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Space updater""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor, Parameter +from mindspore import ops +from mindspore.ops import functional as F +from mindspore.nn import CellList +from mindspore.nn import Optimizer +from mindspore.nn.optim.optimizer import opt_init_args_register +from mindspore.common.initializer import initializer + +from ..system import Molecule +from ..control import Controller +from ..function import functions as func + + +class Updater(Optimizer): + r""" + Optimizer to update parameters of space (coordinates and PBC box). + + Args: + system (Molecule): Simulation system. + controller (Controller): Controller. Default: None + time_step (float): Time step. Default: 1e-3 + velocity (Tensor): Tensor of shape (B, A, D). Data type is float. + Default: None + weight_decay (float): A value for the weight decay. Default: 0.0 + loss_scale (float): A value for the loss scale. Default: 1.0 + + Supported Platforms: + ``Ascend`` ``GPU`` + + Symbols: + B: Batch size, i.e. number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3.
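+ + Note: + A minimal construction sketch (illustrative only; the "water.spce.yaml" + template name is an assumption and may differ in practice): + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindsponge1.system import Molecule + >>> from mindsponge1.optimizer import Updater + >>> system = Molecule(template='water.spce.yaml') + >>> velocity = Tensor(np.zeros(system.coordinate.shape, np.float32)) + >>> updater = Updater(system, time_step=1e-3, velocity=velocity)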
+ + """ + @opt_init_args_register + def __init__(self, + system: Molecule, + controller: Controller = None, + time_step: float = 1e-3, + velocity: Tensor = None, + weight_decay: float = 0.0, + loss_scale: float = 1.0, + ): + + super().__init__( + learning_rate=time_step, + parameters=system.space_parameters(), + weight_decay=weight_decay, + loss_scale=loss_scale, + ) + + self.time_step = Tensor(time_step, ms.float32) + + self.system = system + self.coordinate = self.system.coordinate + self.pbc_box = self.system.pbc_box + + # (B,A) + self.atom_mass = self.system.atom_mass + self.inv_mass = self.system.inv_mass + # (B,A,1) + self._atom_mass = F.expand_dims(self.atom_mass, -1) + self._inv_mass = F.expand_dims(self.inv_mass, -1) + + self.num_walker = system.num_walker + self.num_atoms = system.num_atoms + self.dimension = system.dimension + + self.units = self.system.units + self.boltzmann = self.units.boltzmann + self.kinetic_unit_scale = self.units.kinetic_ref + self.press_unit_scale = self.units.pressure_ref + + if velocity is None: + self.velocity = Parameter(msnp.zeros_like(self.coordinate), name='velocity') + else: + velocity = Tensor(velocity) + if velocity.ndim == 2: + velocity = F.expand_dims(velocity, 0) + if velocity.shape != self.coordinate.shape: + raise ValueError('The shape of velocity '+str(velocity.shape) + + 'must be equal to the shape of coordinate '+str(self.coordinate.shape)+'!') + self.velocity = Parameter(velocity, name='velocity') + + self.num_constraints = 0 + self.num_controller = 0 + if controller is None: + self.controller = None + else: + if isinstance(controller, Controller): + self.num_controller = 1 + controller = [controller] + elif isinstance(controller, list): + self.num_controller = len(controller) + else: + raise TypeError('The type of "controller" must be Controller or list but got: ' + +str(type(controller))) + + self.controller = CellList(controller) + for i in range(self.num_controller): + self.num_constraints += self.controller[i].num_constraints + + self.degrees_of_freedom = system.degrees_of_freedom - self.num_constraints + + self.identity = ops.Identity() + + self.kinetics = None + self.temperature = None + if self.velocity is not None: + kinetics = self.get_kinetics(self.velocity) + temperature = self.get_temperature(kinetics) + # (B,D) + self.kinetics = Parameter(kinetics, name="kinetics") + # (B) + self.temperature = Parameter(temperature, name="temperature") + + self.virial = None + self.pressure = None + if self.pbc_box is not None: + # (B,D) + self.virial = Parameter(initializer( + 'zeros', (self.num_walker, self.dimension), ms.float32), name="virial") + self.pressure = Parameter(initializer( + 'zeros', (self.num_walker, self.dimension), ms.float32), name="pressure") + + self.step = Parameter(Tensor(0, ms.int32), name='updater_step') + + if controller is not None: + for i in range(self.num_controller): + self.controller[i].set_time_step(self.time_step) + self.controller[i].set_degrees_of_freedom(self.degrees_of_freedom) + + def set_step(self, step: int = 0): + """ + set time step. + + Args: + step (int): Time steps. + """ + step = Tensor(step, ms.int32) + F.depend(True, F.assign(self.step, step)) + return self + + def update_coordinate(self, coordinate: Tensor, success: bool = True) -> bool: + """ + update the parameters of coordinate. + + Args: + coordinate (Tensor): Tensor of coordinates. + success (bool): Whether successfully update the parameters. + + Returns: + bool. 
+ """ + success = F.depend(success, F.assign(self.coordinate, coordinate)) + return success + + def update_pbc_box(self, pbc_box: Tensor, success: bool = True) -> bool: + """ + update the parameters of PBC box. + + Args: + pbc_box (Tensor): Tensor of PBC box. + success (bool): Whether successfully update the parameters. + + Returns: + bool. + """ + if self.pbc_box is None: + return success + return F.depend(success, F.assign(self.pbc_box, pbc_box)) + + def update_velocity(self, velocity: Tensor, success: bool = True) -> bool: + """ + update the parameters of velocity. + + Args: + velocity (Tensor): Tensor of velocity. + success (bool): Whether successfully update the parameters. + + Returns: + bool. + """ + return F.depend(success, F.assign(self.velocity, velocity)) + + def update_kinetics(self, kinetics: Tensor, success: bool = True) -> bool: + """ + update the parameters of kinects. + + Args: + kinetics (Tensor): Tensor of kinetics. + success (bool): Whether successfully update the parameters. + + Returns: + bool. + """ + if self.kinetics is None: + return success + return F.depend(success, F.assign(self.kinetics, kinetics)) + + def update_temperature(self, temperature: Tensor, success: bool = True) -> bool: + """ + update the parameters of temperature. + + Args: + temperature (Tensor): Tensor of temperature. + success (bool): Whether successfully update the parameters. + + Returns: + bool. + """ + if self.temperature is None: + return success + return F.depend(success, F.assign(self.temperature, temperature)) + + def update_virial(self, virial: Tensor, success: bool = True) -> bool: + """ + update the parameters of virial. + + Args: + virial (Tensor): Tensor of virial. + success (bool): Whether successfully update the parameters. + + Returns: + bool. + """ + if self.pbc_box is None: + return success + return F.depend(success, F.assign(self.virial, virial)) + + def update_pressure(self, pressure: Tensor, success: bool = True) -> bool: + """ + update the parameters of pressure. + + Args: + pressure (Tensor): Tensor of pressure. + success (bool): Whether successfully update the parameters. + + Returns: + bool. + """ + if self.pbc_box is None: + return success + return F.depend(success, F.assign(self.pressure, pressure)) + + def get_velocity(self) -> Tensor: + """ + get velocity. + + Returns: + Tensor, a Tensor of velocity. + """ + if self.velocity is None: + return None + return self.identity(self.velocity) + + def get_kinetics(self, velocity: Tensor) -> Tensor: + """ + get kinectics. + + Args: + velocity (Tensor): Tensor of velocity. + + Returns: + Tensor, a Tensor of kinetics. + """ + # (B,A,D) + kinetics = 0.5 * self._atom_mass * velocity**2 + # (B,D) <- (B,A,D) + kinetics = F.reduce_sum(kinetics, -2) + return kinetics * self.kinetic_unit_scale + + def get_temperature(self, kinetics: Tensor = None) -> Tensor: + """ + get temperature. + + Args: + kinetics (Tensor): Tensor of kinetics. + + Returns: + Tensor, a Tensor of temperature. + """ + # (B) <- (B,D) + kinetics = F.reduce_sum(kinetics, -1) + return 2 * kinetics / self.degrees_of_freedom / self.boltzmann + + def get_pressure(self, kinetics: Tensor, virial: Tensor, pbc_box: Tensor) -> Tensor: + """ + get pressure. + + Args: + kinetics (Tensor): Tensor of kinetics. + virial (Tensor): Tensor of virial. + pbc_box (Tensor): Tensor of PBC box. + + Returns: + Tensor, a Tensor of pressure. 
+ """ + if self.pbc_box is None: + return None + # (B,D) = ((B,D) - (B, D)) / (B,1) + volume = func.keepdim_prod(pbc_box, -1) + pressure = 2 * (kinetics - virial) / volume + return pressure * self.press_unit_scale + + def get_virial(self, pbc_grad, pbc_box): + """ + get virial. + + Args: + pbc_grad (Tensor): Tensor of the grad of PBC box. + pbc_box (Tensor): Tensor of PBC box. + + Returns: + Tensor, a Tensor of virial. + """ + # (B,D) + return 0.5 * pbc_grad * pbc_box + + def get_dt(self): + """ + get the learning rate of current step. + + Returns: + float, the learning rate of current step. + """ + return self.get_lr() + + def next_step(self, success: bool = True) -> bool: + """ + finish the current optimization step and move to next step. + + Args: + success (bool): Whether move to next step. + + Returns: + bool. + """ + return F.depend(success, F.assign(self.step, self.step+1)) + + def construct(self, gradients: tuple, loss: Tensor = None): + """ + update the parameters of system. + + Returns: + bool. + """ + + gradients = self.decay_weight(gradients) + gradients = self.scale_grad(gradients) + + coordinate = self.coordinate + velocity = self.velocity + force = -gradients[0] + energy = loss + kinetics = self.kinetics + pbc_box = self.pbc_box + virial = None + if self.pbc_box is not None: + virial = self.get_virial(gradients[1], pbc_box) + + step = self.identity(self.step) + if self.controller is not None: + for i in range(self.num_controller): + coordinate, velocity, force, energy, kinetics, virial, pbc_box = \ + self.controller[i](coordinate, velocity, force, energy, kinetics, virial, pbc_box, step) + + temperature = self.get_temperature(kinetics) + pressure = self.get_pressure(kinetics, virial, pbc_box) + + success = True + success = self.update_coordinate(coordinate, success) + success = self.update_velocity(velocity, success) + success = self.update_pbc_box(pbc_box, success) + success = self.update_kinetics(kinetics, success) + success = self.update_temperature(temperature, success) + success = self.update_virial(virial, success) + success = self.update_pressure(pressure, success) + + return self.next_step(success) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/partition/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d884d4e656213563440bc4aa8617bd3e1885b32 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Neighbour list +""" + +from .fullconnect import FullConnectNeighbours +from .distance import DistanceNeighbours +from .grids import GridNeighbours +from .neighbourlist import NeighbourList + +__all__ = ['FullConnectNeighbours', 'DistanceNeighbours', 'GridNeighbours', 'NeighbourList'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/partition/distance.py b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/distance.py new file mode 100644 index 0000000000000000000000000000000000000000..16d90a704838612b990d1bd84d1eb69581810cc6 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/distance.py @@ -0,0 +1,241 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Use the distances between atoms to calculate neighbour list +""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor +from mindspore.nn import Cell +from mindspore import ops +from mindspore.ops import functional as F + +from ..function.functions import get_integer +from ..function.functions import calc_distance +from ..function.functions import calc_distance_with_pbc +from ..function.functions import calc_distance_without_pbc + + +class DistanceNeighbours(Cell): + r""" + Neighbour list calculated by distance. + + Args: + cutoff (float): Cutoff distance. + num_neighbours (int): Maximum number of neighbour atoms. If input "None", all the + other atoms (A - 1) will be regarded as potential neighbours. + Default: None + atom_mask (Tensor): Tensor of shape (B, A). Data type is bool\_. + Mask of atoms in the system. Default: None + exclude_index (Tensor): Tensor of shape (B, A, Ex). Data type is int32. + Index of neighbour atoms which could be excluded from the neighbour list. + Default: None + use_pbc (bool): Whether to use periodic boundary condition. Default: None + cutoff_scale (float): Factor to scale the cutoff distance. Default: 1.2 + large_dis (float): A large distance value used to pad the distances of masked + or non-existing atoms. Default: 1e4 + + Supported Platforms: + ``Ascend`` ``GPU`` + + Symbols: + B: Number of simulation walker. + A: Number of atoms in system. + Ex: Maximum number of excluded neighbour atoms.
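+ + Note: + A minimal sketch (illustrative only) with random coordinates and no PBC; with + "num_neighbours" left as None, all 15 other atoms are returned per atom: + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindsponge1.partition import DistanceNeighbours + >>> crd = Tensor(np.random.rand(1, 16, 3), ms.float32) + >>> neigh = DistanceNeighbours(cutoff=0.6, use_pbc=False) + >>> distances, neighbours, mask = neigh(crd) + >>> neighbours.shape + (1, 16, 15)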
+ """ + + def __init__(self, + cutoff: float, + num_neighbours: int = None, + atom_mask: Tensor = None, + exclude_index: Tensor = None, + use_pbc: bool = None, + cutoff_scale: float = 1.2, + large_dis: float = 1e4, + ): + + super().__init__() + + self.cutoff = Tensor(cutoff, ms.float32) + self.cutoff_scale = Tensor(cutoff_scale, ms.float32) + self.scaled_cutoff = self.cutoff * self.cutoff_scale + + self.num_neighbours = get_integer(num_neighbours) + + self.large_dis = Tensor(large_dis, ms.float32) + + self.emtpy_atom_shift = 0 + self.atom_mask = None + self.has_empty_atom = False + if atom_mask is not None: + # (B,A) + self.atom_mask = Tensor(atom_mask, ms.bool_) + if self.atom_mask.ndim == 1: + self.atom_mask = F.expand_dims(self.atom_mask, 0) + + self.has_empty_atom = F.logical_not(self.atom_mask.all()) + if self.has_empty_atom: + emtpy_atom_mask = F.logical_not(self.atom_mask) + # (B,1,A) + self.emtpy_atom_shift = F.expand_dims(emtpy_atom_mask, -2) * self.large_dis + + self.exclude_index = None + if exclude_index is not None: + # (B,A,E) + self.exclude_index = Tensor(exclude_index, ms.int32) + if self.exclude_index.ndim == 2: + self.exclude_index = F.expand_dims(self.exclude_index, 0) + + if use_pbc is None: + self.get_distances = calc_distance + else: + if use_pbc: + self.get_distances = calc_distance_with_pbc + else: + self.get_distances = calc_distance_without_pbc + + self.sort = ops.Sort(-1) + self.reduce_all = ops.ReduceAll() + + def set_exclude_index(self, exclude_index: Tensor): + # (B,A,Ex) + self.exclude_index = Tensor(exclude_index, ms.int32) + if self.exclude_index.ndim == 2: + self.exclude_index = F.expand_dims(self.exclude_index, 0) + return self + + def print_info(self): + return self + + def check_neighbours_number(self, neighbour_mask: Tensor): + """ + check number of neighbours in neighbour list. + + Args: + neighbour_mask (Tensor): The neighbour list mask. + """ + max_neighbours = F.cast(msnp.max(F.cast(msnp.sum(neighbour_mask, -1), ms.float32)), ms.int32) + if max_neighbours > self.num_neighbours: + print( + '================================================================================') + print( + 'Warning! Warning! Warning! Warning! Warning! Warning! Warning! Warning! Warning!') + print( + '--------------------------------------------------------------------------------') + print('The max number of neighbour atoms is larger than that in neighbour list!') + print('The max number of neighbour atoms:') + print(max_neighbours) + print('The number of neighbour atoms in neighbour list:') + print(self.num_neighbours) + print('Please increase the value of grid_num_scale or num_neighbours!') + print( + '================================================================================') + return self + + def construct(self, + coordinate: Tensor, + pbc_box: Tensor = None, + atom_mask: Tensor = None, + exclude_index: Tensor = None + ): + r""" + Calculate distances and neighbours. + + Args: + coordinate (Tensor): Tensor of (B, A, D). Data type is float. + Position coordinates of atoms. + pbc_box (Tensor): Tensor of (B, D). Data type is bool. + Periodic boundary condition box. Default: None + atom_mask (Tensor): Tensor of (B, A). Data type is bool. + Atomic mask. + exclude_index (Tensor): Tensor of (B, A, Ex). Data type is int. + Index of the atoms that should be excluded from the neighbour list. + Default: None + + Returns: + - distances (Tensor), Tensor of (B, A, N). Data type is float. + - neighbours (Tensor), Tensor of (B, A, N). Data type is int. 
+ - neighbour_mask (Tensor), Tensor of (B, A, N). Data type is bool. + + Symbols: + B: Batch size. + A: Number of atoms in system. + N: Number of neighbour atoms. + D: Dimension of position coordinates. + Ex: Maximum number of excluded neighbour atoms. + """ + + # A + num_atoms = coordinate.shape[-2] + # (B,A,A) <- (B,A,1,3) - (B,1,A,3) + distances = self.get_distances(F.expand_dims( + coordinate, -2), F.expand_dims(coordinate, -3), pbc_box).squeeze(-1) + + if atom_mask is None: + atom_mask = self.atom_mask + if self.has_empty_atom: + # (B,A,A) + (B,1,A) + distances += self.emtpy_atom_shift + else: + if not atom_mask.all(): + emtpy_atom_mask = F.logical_not(atom_mask) + # (B,1,A) + emtpy_atom_shift = F.expand_dims( + emtpy_atom_mask, -2) * self.large_dis + distances += emtpy_atom_shift + + distances, neighbours = self.sort(distances) + # (B,A) + neighbour_mask = distances < self.scaled_cutoff + + if self.num_neighbours is None: + num_neighbours = num_atoms - 1 + else: + num_neighbours = self.num_neighbours + + distances = distances[..., 1:num_neighbours+1] + neighbours = neighbours[..., 1:num_neighbours+1] + neighbour_mask = neighbour_mask[..., 1:num_neighbours+1] + if self.num_neighbours is not None: + self.check_neighbours_number(neighbour_mask) + + if exclude_index is None: + exclude_index = self.exclude_index + if exclude_index is not None: + # (B,A,n,E) <- (B,A,n,1) != (B,A,1,E) + exc_mask = F.expand_dims( + neighbours, -1) != F.expand_dims(exclude_index, -2) + # (B,A,n) + exc_mask = self.reduce_all(exc_mask, -1) + neighbour_mask = F.logical_and(neighbour_mask, exc_mask) + + if atom_mask is not None: + # (B,A,n) <- (B,A,n) && (B,A,1) + neighbour_mask = F.logical_and( + neighbour_mask, F.expand_dims(atom_mask, -1)) + + # (B,A,n) + no_idx = msnp.arange(num_atoms).reshape(1, -1, 1) + neighbours = msnp.where(neighbour_mask, neighbours, no_idx) + + return distances, neighbours, neighbour_mask diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/partition/fullconnect.py b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/fullconnect.py new file mode 100644 index 0000000000000000000000000000000000000000..d2040660f71a346a2d8cc361efaa7504a88aced5 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/fullconnect.py @@ -0,0 +1,136 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Full connected neighbour list +""" + +import mindspore as ms +from mindspore import numpy as msnp +from mindspore import Tensor +from mindspore import ops +from mindspore.nn import Cell +from mindspore.ops import functional as F + + +class FullConnectNeighbours(Cell): + r""" + Full connected neighbour list. + + Args: + num_atoms (int): Number of atoms. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, num_atoms: int): + super().__init__() + self.num_atoms = num_atoms + self.num_neighbours = num_atoms - 1 + + # neighbours for no connection (A*N) + # (A,1) + no_idx = msnp.arange(self.num_atoms).reshape(-1, 1) + + # (N) + nrange = msnp.arange(self.num_neighbours) + + # neighbours for full connection (A,N) + # [[1,2,3,...,N], + # [0,2,3,...,N], + # [0,1,3,....N], + # ............., + # [0,1,2,...,N-1]] + fc_idx = nrange + F.cast(no_idx <= nrange, ms.int32) + no_idx = msnp.broadcast_to( + no_idx, (self.num_atoms, self.num_neighbours)) + idx_mask = fc_idx > no_idx + + # (1,A,N) + self.fc_idx = F.expand_dims(fc_idx, 0) + self.no_idx = F.expand_dims(no_idx, 0) + self.idx_mask = F.expand_dims(idx_mask, 0) + + self.shape = (self.num_atoms, self.num_neighbours) + self.fc_mask = msnp.broadcast_to(Tensor(True), (1,)+self.shape) + + self.reduce_all = ops.ReduceAll() + + def set_exclude_index(self, _exclude_index: Tensor): + """ + Dummy. + + Args: + _exclude_index (Tensor): Tensor of exclude indexes. + """ + # pylint: disable=invalid-name + return self + + def print_info(self): + """print information""" + return self + + def construct(self, atom_mask: Tensor = None, exclude_index: Tensor = None): + r""" + Calculate the full connected neighbour list. + + Args: + atom_mask (Tensor): Tensor of shape (B, A). Data type is bool. + exclude_index (Tensor): Tensor of shape (B, A, Ex). Data type is int. + + Returns: + - neighbours (Tensor), Tensor of shape (B, A, N). Data type is int. + - neighbour_mask (Tensor), Tensor of shape (B, A, N). Data type is bool. + + Symbols: + B: Batch size. + A: Number of atoms in system. + N: Number of neighbour atoms. + D: Dimension of position coordinates. + Ex: Maximum number of excluded neighbour atoms. 
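+ + Note: + A minimal sketch (illustrative only): with 4 atoms and no mask, each atom + gets the other 3 atoms as neighbours. + + Examples: + >>> from mindsponge1.partition import FullConnectNeighbours + >>> fc = FullConnectNeighbours(4) + >>> neighbours, neighbour_mask = fc() + >>> neighbours.shape + (1, 4, 3)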
+ + """ + if atom_mask is None: + neighbours = self.fc_idx + neighbour_mask = self.fc_mask + else: + + # (B,1,N) + mask0 = F.expand_dims(atom_mask[:, :-1], -2) + mask1 = F.expand_dims(atom_mask[:, 1:], -2) + + # (B,A,N) + neighbour_mask = msnp.where(self.idx_mask, mask1, mask0) + neighbour_mask = F.logical_and(F.expand_dims(atom_mask, -1), neighbour_mask) + neighbours = msnp.where(neighbour_mask, self.fc_idx, self.no_idx) + + if exclude_index is not None: + # (B,A,N,E) <- (B,A,N,1) vs (B,A,1,E) + exc_mask = F.expand_dims( + neighbours, -1) != F.expand_dims(exclude_index, -2) + # (B,A,N) + exc_mask = self.reduce_all(exc_mask, -1) + neighbour_mask = F.logical_and(neighbour_mask, exc_mask) + neighbours = msnp.where(neighbour_mask, neighbours, self.no_idx) + + return neighbours, neighbour_mask diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/partition/grids.py b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/grids.py new file mode 100644 index 0000000000000000000000000000000000000000..605f1625a61296d5046429949df876c9662afde5 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/grids.py @@ -0,0 +1,477 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Use grids to calculate neighbour list +""" + +import itertools +import numpy as np +import scipy.stats +import mindspore as ms +import mindspore.numpy as msnp +from mindspore.nn import Cell +from mindspore import Tensor +from mindspore import ops +from mindspore.ops import functional as F + +from ..function.functions import get_integer, displace_in_box + + +class GridNeighbours(Cell): + r""" + Neighbour list calculated by grids. + + Args: + cutoff (float): Cutoff distance. + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float32. + position coordinates of atoms in the simulation system. + pbc_box (Tensor): Tensor of shape (B, A, D). Data type is float32. + Box size of periodic boundary condition. Default: None + atom_mask (Tensor): Tensor of shape (B, A). Data type is bool\_. + Mask of atoms in the system. + Default: None + exclude_index (Tensor): Tensor of shape (B, A, Ex). Data type is int32. + Index of neighbour atoms which could be excluded from the neighbour list. + Default: None + num_neighbours (int): Number of neighbours. If input "None", this value will be calculated by + the ratio of the number of neighbouring grids to the total number of grids. + Default: None + cell_capacity (int): Capacity number of atoms in grid cell. If input "None", this value will be multiplied + by a factor of the maximum number of atoms in the grid cell at the initial coordinate. 
+ Default: None + num_cell_cut (int): Number of subdivision of grid cells according to the cutoff. Default: 1 + cutoff_scale (float): Factor to scale the cutoff distance. Default: 1.2 + cell_cap_scale (float): Factor to scale "cell_capacity". Default: 1.25 + grid_num_scale (float): Scale factor to calculate "num_neighbours" by the ratio of grids. + If "num_neighbours" is not None, it will not be used. Default: 1.5 + + Supported Platforms: + ``Ascend`` ``GPU`` + + Symbols: + B: Number of simulation walker. + A: Number of atoms in system. + D: Dimension of position coordinates. + Ex: Maximum number of excluded neighbour atoms. + """ + + def __init__(self, + cutoff: float, + coordinate: Tensor, + pbc_box: Tensor = None, + atom_mask: Tensor = None, + exclude_index: Tensor = None, + num_neighbours: int = None, + cell_capacity: int = None, + num_cell_cut: int = 1, + cutoff_scale: float = 1.2, + cell_cap_scale: float = 1.25, + grid_num_scale: float = 1.5, + ): + + super().__init__() + + self.num_atoms = coordinate.shape[-2] + self.dim = coordinate.shape[-1] + + self.cutoff = Tensor(cutoff, ms.float32) + + self.cutoff_scale = Tensor(cutoff_scale, ms.float32) + self.cell_cap_scale = Tensor(cell_cap_scale, ms.float32) + self.grid_num_scale = Tensor(grid_num_scale, ms.float32) + + # N_c + num_cell_cut = get_integer(num_cell_cut) + + self.grid_cutoff = self.cutoff / num_cell_cut + self.scaled_cutoff = self.cutoff * self.cutoff_scale + self.scaled_grid_cutoff = self.grid_cutoff * self.cutoff_scale + + if pbc_box is None: + self.use_pbc = False + # (B,1,D) <- (B,A,D) + rmax = msnp.max(coordinate, -2, keepdims=True) + rmin = msnp.min(coordinate, -2, keepdims=True) + center = msnp.mean(coordinate, -2, keepdims=True) + # (B,2,D) + rhalf = msnp.concatenate((rmax-center, center-rmin)) + # (B,D) + rhalf = msnp.max(rhalf, -2) + # (D) + rhalf = msnp.max(rhalf, 0) + box = rhalf * 2 + self.origin_grid_dims = msnp.ceil(box/self.scaled_grid_cutoff).astype(np.int32) + self.grid_dims = self.origin_grid_dims + 2 + box = self.origin_grid_dims * self.scaled_grid_cutoff + self.half_box = box / 2 + else: + self.use_pbc = True + center = None + # (B,D) + box = Tensor(pbc_box, ms.float32) + if box.ndim == 1: + box = F.expand_dims(pbc_box, 0) + self.half_box = box / 2 + if (self.cutoff > self.half_box).any(): + raise ValueError( + '"cutoff" cannot be greater than half the length of the shortest side of the PBC pbc_box!') + # (B,D) + self.origin_grid_dims = msnp.floor(box/self.scaled_grid_cutoff) + # (D) + self.origin_grid_dims = Tensor( + np.min(self.origin_grid_dims.asnumpy(), axis=0).astype(np.int32)) + self.grid_dims = self.origin_grid_dims + + # (D) + grid_mask = self.grid_dims > 3 + self.grid_dims = msnp.where(grid_mask, self.grid_dims, 1) + self.max_grid_index = self.origin_grid_dims - 1 + + # G + self.num_grids = int(np.prod(self.grid_dims.asnumpy())) + + # (D) + self.grid_factor = msnp.cumprod(self.grid_dims[::-1], axis=-1) + self.grid_factor = msnp.concatenate( + (self.grid_factor[1::-1], Tensor([1], ms.int32)), axis=-1) + + # (G,D) + grids = [np.arange(dim).tolist() for dim in self.grid_dims.asnumpy()] + grids = Tensor(tuple(itertools.product(*grids)), ms.int32) + + # (B,1,D) + box = F.expand_dims(box, -2) + if self.use_pbc: + # (B,1,D) = (B,D) / (D) + self.cell = box / self.grid_dims + if (self.cell < self.grid_cutoff).any(): + raise ValueError( + 'The cell length of cannot be smaller than cutoff!') + # (B,A,D) = ((B,A,D) - (D)) / (B,1,D) + atom_grid_idx = msnp.floor( + (displace_in_box(coordinate, 
pbc_box))/self.cell).astype(ms.int32) + else: + self.cell = msnp.broadcast_to(self.scaled_grid_cutoff, (self.dim,)) + # (B,A,D) = (B,A,D) - (B,1,D) + (D) + scaled_coord = (coordinate - center + + self.half_box) / self.scaled_grid_cutoff + scaled_coord = msnp.where(scaled_coord < 0, 0, scaled_coord) + atom_grid_idx = msnp.floor(scaled_coord).astype(ms.int32) + atom_grid_idx = msnp.where(atom_grid_idx < self.origin_grid_dims, + atom_grid_idx, self.max_grid_index) + atom_grid_idx += 1 + + # (B,A) <- (B,A,D) * (D) + atom_grid_idx = msnp.sum(atom_grid_idx * self.grid_factor, axis=-1) + + # (D): [n_1,n_2,...n_D] + num_extend_neigh = np.where(grid_mask.asnumpy(), num_cell_cut, 0) + dim_neigh_grids = num_extend_neigh * 2 + 1 + self.num_neigh_grids = int(np.prod(dim_neigh_grids)) + self.dim_neigh_grids = Tensor(dim_neigh_grids) + + if cell_capacity is None: + # (B,1) + _, max_num_in_cell = scipy.stats.mode( + atom_grid_idx.asnumpy(), axis=1) + max_num_in_cell = get_integer(np.max(max_num_in_cell)) + # C + cell_capacity = get_integer( + msnp.ceil(max_num_in_cell*self.cell_cap_scale)) + self.cell_capacity = int(min(cell_capacity, self.num_atoms)) + else: + self.cell_capacity = get_integer(cell_capacity) + + # N_cap = n * C + self.neigh_capacity = self.num_neigh_grids * self.cell_capacity + + # G*C + self.grid_cap = self.num_grids * self.cell_capacity + self.sort_id_factor = msnp.mod( + msnp.arange(self.num_atoms), self.cell_capacity) + + # (n,D) + neigh_offsets = [np.arange(-num_extend_neigh[i], num_extend_neigh[i]+1, + dtype=np.int32).tolist() for i in range(self.dim)] + neigh_offsets = Tensor( + tuple(itertools.product(*neigh_offsets)), ms.int32) + + if num_neighbours is None: + if self.use_pbc: + # N' = ceil(A * n / G * n_scale) + num_neighbours = msnp.ceil( + self.num_atoms*self.num_neigh_grids/self.num_grids*self.grid_num_scale).asnumpy() + # N = min(N',n*C) + self.num_neighbours = int(min(num_neighbours, self.num_atoms)) + else: + self.num_neighbours = int( + min(self.neigh_capacity, self.num_atoms)) + else: + self.num_neighbours = get_integer(num_neighbours) + if self.num_neighbours > self.num_atoms: + raise ValueError( + 'The value of "num_neighbours" cannot be larger than the number of atoms!') + + # (G,n,D) + neigh_grids = F.expand_dims(grids, -2) + neigh_offsets + # neigh_grids = msnp.select([neigh_grids<0, neigh_grids>=self.grid_dims], + # [neigh_grids+self.grid_dims, neigh_grids-self.grid_dims], neigh_grids) + neigh_grids = F.select( + neigh_grids < 0, neigh_grids+self.grid_dims, neigh_grids) + neigh_grids = F.select(neigh_grids >= self.grid_dims, neigh_grids-self.grid_dims, neigh_grids) + + # (G*n) + self.neigh_idx = msnp.sum( + neigh_grids*self.grid_factor, axis=-1).reshape(-1) + self.atom_idx = msnp.arange( + self.num_atoms).reshape(1, self.num_atoms, 1) + + if atom_mask is None: + self.atom_mask = None + else: + # (B,A) + self.atom_mask = Tensor(atom_mask, ms.bool_) + if self.atom_mask.shape[-1] != self.num_atoms: + raise ValueError('The number of atoms in atom_mask ('+str(self.atom_mask.shape[-1]) + + ') is mismatch with that in coordinate ('+str(self.num_atoms)+').') + if self.atom_mask.ndim == 1: + self.atom_mask = F.expand_dims(self.atom_mask, 0) + + if exclude_index is None: + self.exclude_index = None + else: + # (B,A,Ex) + self.exclude_index = Tensor(exclude_index, ms.int32) + if self.exclude_index.shape[-2] != self.num_atoms: + raise ValueError('The number of atoms in exclude_index ('+str(self.exclude_index.shape[-2]) + + ') is mismatch with that in coordinate 
('+str(self.num_atoms)+').') + if self.exclude_index.ndim == 2: + self.exclude_index = F.expand_dims(self.exclude_index, 0) + + self.sort = ops.Sort(-1) + self.reduce_all = ops.ReduceAll() + + def set_exclude_index(self, exclude_index: Tensor): + """ + set excluded neighbour index. + + Args: + exclude_index (Tensor): Tensor of excluded neighbour indexes. + """ + # (B,A,Ex) + self.exclude_index = Tensor(exclude_index, ms.int32) + if self.exclude_index.shape[-2] != self.num_atoms: + raise ValueError('The number of atoms in exclude_index ('+str(self.exclude_index.shape[-2]) + + ') is mismatch with that in coordinate ('+str(self.num_atoms)+').') + if self.exclude_index.ndim == 2: + self.exclude_index = F.expand_dims(self.exclude_index, 0) + return self + + def check_neighbours_number(self, grid_neigh_atoms: Tensor, num_neighbours: int = None): + """ + check number of neighbours in neighbour list. + + Args: + grid_neigh_atoms (Tensor): Tensor of grid of neighbour atoms. + num_neighbours (int): Number of neighbours. + """ + if num_neighbours is None: + num_neighbours = self.num_neighbours + max_neighbours = msnp.sum(grid_neigh_atoms != self.num_atoms, axis=-1) + max_neighbours = F.cast( + msnp.max(F.cast(max_neighbours, ms.float32)), ms.int32) + if max_neighbours > num_neighbours: + print( + '================================================================================') + print( + 'Warning! Warning! Warning! Warning! Warning! Warning! Warning! Warning! Warning!') + print( + '--------------------------------------------------------------------------------') + print('The max number of neighbour atoms ' + 'is larger than that in neighbour list!') + print('The max number of neighbour atoms:') + print(max_neighbours) + print('The number of neighbour atoms in neighbour list:') + print(num_neighbours) + print('Please increase the value of grid_num_scale or num_neighbours!') + print( + '================================================================================') + return self + + def print_info(self): + """print information of neighbour list""" + print('Calculate neighbour list from grids') + print(' Cutoff distance: '+str(self.cutoff)) + print(' Grid cell length: '+str(self.scaled_grid_cutoff)) + print(' Initial size of grid cell: '+str(F.squeeze(self.cell))) + print(' Grid dimensions: '+str(self.grid_dims)) + print(' Number of Grids: '+str(self.num_grids)) + print(' Grid cell capacity: '+str(self.cell_capacity)) + print(' Dimension of neighbour cells: '+str(self.dim_neigh_grids)) + print(' Number of atoms: '+str(self.num_atoms)) + print(' Max number of neighbour atoms: '+str(self.num_neighbours)) + return self + + def get_neighbours_from_grids(self, atom_grid_idx: Tensor, num_neighbours: int): + """ + get neighbour list from grids + + Args: + atom_grid_idx (Tensor): Tensor of atoms grid indexes. + num_neighbours (int): Number of neighbours. + + Returns: + list, neighbour list from grids. 
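+ + Note: + The lookup works in three steps: (1) atoms are sorted by grid index and + scattered into a (G, C) grid-to-atom table padded with "num_atoms"; + (2) for every grid, the atoms of its n neighbouring cells are gathered; + (3) the gathered indices are sorted and truncated to "num_neighbours".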
+ """ + #pylint: disable=unused-argument + # Sorted grid index + # (B,A) + sorted_grid_idx, sort_arg = self.sort(F.cast(atom_grid_idx, ms.float32)) + sorted_grid_idx = F.cast(sorted_grid_idx, ms.int32) + sorted_grid_idx = sorted_grid_idx * self.cell_capacity + self.sort_id_factor + + num_walker = atom_grid_idx.shape[0] + # Atom index in each grid + # (B,G*C) + scatter_shape = (num_walker, self.grid_cap) + grid_atoms = msnp.full(scatter_shape, self.num_atoms) + if num_walker == 1: + grid_atoms[:, sorted_grid_idx[0]] = sort_arg + else: + # (B,1,1) + batch_idx = msnp.arange(num_walker).reshape(num_walker, 1, 1) + # (B,A,1) + batch_idx = msnp.broadcast_to( + batch_idx, (num_walker, self.num_atoms, 1)) + # (B,A,2) + scatter_idx = msnp.concatenate( + (batch_idx, F.expand_dims(sorted_grid_idx, -1)), axis=-1) + grid_atoms = F.tensor_scatter_update( + grid_atoms, scatter_idx, sort_arg) + # (B,G,C) + grid_atoms = F.reshape( + grid_atoms, (num_walker, self.num_grids, self.cell_capacity)) + + # Atom index in neighbour grids for each grid + # (B,G*n,C) + grid_neigh_atoms = F.gather(grid_atoms, self.neigh_idx, -2) + # (B,G,n,C) + shape = (num_walker, self.num_grids, + self.num_neigh_grids, self.cell_capacity) + grid_neigh_atoms = F.reshape(grid_neigh_atoms, shape) + # (B,G,n*C) + shape = (num_walker, self.num_grids, + self.num_neigh_grids*self.cell_capacity) + grid_neigh_atoms = F.reshape(grid_neigh_atoms, shape) + grid_neigh_atoms, _ = self.sort(F.cast(grid_neigh_atoms, ms.float32)) + grid_neigh_atoms = F.cast(grid_neigh_atoms, ms.int32) + + self.check_neigbours_number(grid_neigh_atoms, num_neighbours) + grid_neigh_atoms = grid_neigh_atoms[..., :num_neighbours] + + # neighbour atoms for each atom + # (B,A,N) + if num_walker == 1: + return grid_neigh_atoms[:, atom_grid_idx[0], :] + return msnp.take_along_axis(grid_neigh_atoms, F.expand_dims(atom_grid_idx, -1), -2) + + def construct(self, + coordinate: Tensor, + pbc_box: Tensor = None, + atom_mask: Tensor = None, + exclude_index: Tensor = None, + ): + """ + Calculate neighbour list. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Atom coordinates. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + PBC box.Default: None + atom_mask (Tensor): Tensor of shape (B, A). Data type is bool. + Mask of atoms. Default: None + exclude_index (Tensor): Tensor of shape (B, A, Ex). Data type is int. + Index of atoms that should be exclude from neighbour list. + Default: None + + Returns: + - neighbours(Tensor). + - mask(Tensor). + + Sysmbols: + B: Number of simulation walker. + A: Number of atoms in system. + D: Dimension of position coordinates. + Ex: Maximum number of excluded neighbour atoms. + """ + + if self.use_pbc: + if pbc_box is None: + cell = self.cell + else: + # (B,1,D) = (B,D) / (D) + cell = F.expand_dims(pbc_box/self.grid_dims, -2) + if (cell < self.grid_cutoff).any(): + print('Warning! 
The cell length is smaller than cutoff') + # (B,A,D) = ((B,A,D) - (D)) / (B,1,D) + atom_grid_idx = msnp.floor( + (displace_in_box(coordinate, pbc_box))/cell).astype(ms.int32) + else: + # (B,1,D) <- (B,A,D) + center = msnp.mean(coordinate, -2, keepdims=True) + # (B,A,D) = (B,A,D) - (B,1,D) + (D) + scaled_coord = (coordinate - center + + self.half_box) / self.scaled_grid_cutoff + scaled_coord = msnp.where(scaled_coord < 0, 0, scaled_coord) + atom_grid_idx = msnp.floor(scaled_coord).astype(ms.int32) + atom_grid_idx = msnp.where(atom_grid_idx < self.origin_grid_dims, + atom_grid_idx, self.max_grid_index) + atom_grid_idx += 1 + + # Grid index for each atom + # (B,A) <- (B,A,D) * (D) + atom_grid_idx = msnp.sum(atom_grid_idx * self.grid_factor, axis=-1) + + neighbours = self.get_neighbours_from_grids( + atom_grid_idx, self.num_neighbours) + + mask = neighbours != self.num_atoms + atom_idx = msnp.broadcast_to(self.atom_idx, neighbours.shape) + neighbours = F.select(mask, neighbours, atom_idx) + mask = (neighbours != atom_idx) + + if atom_mask is None: + atom_mask = self.atom_mask + + if exclude_index is None: + exclude_index = self.exclude_index + if exclude_index is not None: + # (B,A,N,Ex) = (B,A,N,1) != (B,1,1,Ex) + exmask = (F.expand_dims(neighbours, -1) != + F.expand_dims(exclude_index, -2)) + # (B,A,N) + exmask = self.reduce_all(exmask, -1) + mask = F.logical_and(mask, exmask) + + return neighbours, mask diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/partition/neighbourlist.py b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/neighbourlist.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0fd3316e7c927e64d2538b31128597f81709e4 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/partition/neighbourlist.py @@ -0,0 +1,265 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Neighbour list +""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor, ms_function +from mindspore import Parameter +from mindspore import ops +from mindspore.ops import functional as F +from mindspore.nn import Cell + +from . import FullConnectNeighbours, DistanceNeighbours, GridNeighbours +from ..system import Molecule + + +class NeighbourList(Cell): + r""" + Neighbour list. + + Args: + system (Molecule): Simulation system. + cutoff (float): Cutoff distance. Default: None + update_steps (int): Steps of update frequency. Default: 20 + exclude_index (Tensor): Tensor of shape (B, A, Ex). Data type is int. + Index of neighbour atoms which could be excluded from the neighbour list. 
+ Default: None + num_neighbours (int): Number of neighbours. If input "None", this value will be calculated by + the ratio of the number of neighbouring grids to the total number of grids. + Default: None + num_cell_cut (int): Number of subdivision of grid cells according to cutoff. Default: 1 + cutoff_scale (float): Factor to scale cutoff distance. Default: 1.2 + cell_cap_scale (float): Scale factor for "cell_capacity". Default: 1.25 + grid_num_scale (float): Scale factor to calculate "num_neighbours" by ratio of grids. + If "num_neighbours" is not None, it will not be used. Default: 2 + large_dis (float): A large distance value used to pad the distances of masked + or non-existing atoms. Default: 1e4 + use_grids (bool): Whether to use grids to calculate the neighbour list. Default: None + + Supported Platforms: + ``Ascend`` ``GPU`` + + Symbols: + B: Number of simulation walker. + A: Number of atoms in system. + Ex: Maximum number of excluded neighbour atoms. + """ + + def __init__(self, + system: Molecule, + cutoff: float = None, + update_steps: int = 20, + exclude_index: Tensor = None, + num_neighbours: int = None, + num_cell_cut: int = 1, + cutoff_scale: float = 1.2, + cell_cap_scale: float = 1.25, + grid_num_scale: float = 2, + large_dis: float = 1e4, + use_grids: bool = None, + ): + + super().__init__() + + self.num_walker = system.num_walker + self.coordinate = system.get_coordinate() + self.num_atoms = self.coordinate.shape[-2] + self.dim = self.coordinate.shape[-1] + + self.pbc_box = system.get_pbc_box() + + self.atom_mask = system.atom_mask + self.exclude_index = exclude_index + if exclude_index is not None: + self.exclude_index = Tensor(exclude_index, ms.int32) + + large_dis = Tensor(large_dis, ms.float32) + self.units = system.units + self.use_grids = use_grids + + self.no_mask = False + if cutoff is None: + self.cutoff = None + self.neighbour_list = FullConnectNeighbours(self.num_atoms) + if self.exclude_index is None: + self.no_mask = True + else: + self.cutoff = Tensor(cutoff, ms.float32) + if self.use_grids or self.use_grids is None: + self.neighbour_list = GridNeighbours( + cutoff=self.cutoff, + coordinate=self.coordinate, + pbc_box=self.pbc_box, + atom_mask=self.atom_mask, + exclude_index=self.exclude_index, + num_neighbours=num_neighbours, + num_cell_cut=num_cell_cut, + cutoff_scale=cutoff_scale, + cell_cap_scale=cell_cap_scale, + grid_num_scale=grid_num_scale, + ) + if self.neighbour_list.neigh_capacity >= self.num_atoms: + if self.use_grids: + print('Warning! The number of neighbour atoms in GridNeighbours (' + + str(self.neighbour_list.neigh_capacity) + + ') is not less than the number of atoms ('+str(self.num_atoms) + + ').
It would be more efficient to use "DistanceNeighbours"' + ' (set "use_grids" to False or None).') + else: + self.use_grids = False + else: + self.use_grids = True + + if not self.use_grids: + if num_neighbours is None and self.pbc_box is not None: + op = ms.ops.ReduceProd(keep_dims=True) + volume = msnp.min(op(self.pbc_box, -1)) + num_neighbours = grid_num_scale * self.num_atoms * \ + msnp.power(cutoff*cutoff_scale, self.dim) / volume + num_neighbours = num_neighbours.astype(ms.int32) + + self.neighbour_list = DistanceNeighbours( + cutoff=self.cutoff, + num_neighbours=num_neighbours, + atom_mask=self.atom_mask, + exclude_index=self.exclude_index, + use_pbc=self.pbc_box is not None, + cutoff_scale=cutoff_scale, + large_dis=large_dis + ) + + self.num_neighbours = self.neighbour_list.num_neighbours + + self.update_steps = update_steps + if update_steps <= 0: + raise ValueError('update_steps must be larger than 0!') + + index, mask = self.calcaulate(self.coordinate, self.pbc_box) + + self.neighbours = Parameter( + index, name='neighbours', requires_grad=False) + if self.cutoff is None and self.exclude_index is None: + self.neighbour_mask = None + else: + self.neighbour_mask = Parameter( + mask, name='neighbour_mask', requires_grad=False) + + self.identity = ops.Identity() + + def set_exclude_index(self, exclude_index: Tensor): + """ + set exclude index. + + Args: + exclude_index (Tensor): Tensor of exclude indexes. + + Returns: + bool. + """ + if exclude_index is None: + return True + self.exclude_index = Tensor(exclude_index, ms.int32) + self.neighbour_list.set_exclude_index(exclude_index) + index, mask = self.calcaulate(self.coordinate, self.pbc_box) + success = True + success = F.depend(success, F.assign(self.neighbours, index)) + if self.neighbour_mask is None: + self.neighbour_mask = Parameter( + mask, name='neighbour_mask', requires_grad=False) + else: + success = F.depend(success, F.assign(self.neighbour_mask, mask)) + return success + + def print_info(self): + """print information of neighbour list""" + self.neighbour_list.print_info() + return self + + @ms_function + def calcaulate(self, coordinate: Tensor, pbc_box: Tensor = None): + """ + calculate neighbour list. + + Args: + coordinate (Tensor): Tensor of coordinates. + pbc_box (Tensor): Tensor of PBC box. + + Returns: + - index(Tensor). + - mask(Tensor). + """ + coordinate = F.stop_gradient(coordinate) + pbc_box = F.stop_gradient(pbc_box) + if self.cutoff is None: + return self.neighbour_list(self.atom_mask, self.exclude_index) + + if self.use_grids: + return self.neighbour_list(coordinate, pbc_box) + + _, index, mask = self.neighbour_list( + coordinate, pbc_box, self.atom_mask, self.exclude_index) + + return index, mask + + def get_neighbour_list(self): + """ + get neighbour list. + + Returns: + - index(Tensor). + - mask(Tensor). + """ + index = self.identity(self.neighbours) + mask = None + if self.neighbour_mask is not None: + mask = self.identity(self.neighbour_mask) + return index, mask + + def construct(self, coordinate: Tensor, pbc_box: Tensor = None) -> bool: + r""" + Gather coordinate of neighbours atoms. + + Args: + coordinate (Tensor): Tensor of shape (B,A,D). Data type is float. + pbc_box (Tensor): Tensor of shape (B,D). Data type is float. + + Returns: + - neighbours (Tensor), Tensor of shape (B,A,N). Data type is int. + - neighbour_mask (Tensor or None), Tensor of shape (B,A,N). Data type is bool. + + Symbols: + B: Number of simulation walker. + A: Number of atoms in system. + N: Number of neighbour atoms.
+ D: Dimension of position coordinates. + Ex: Maximum number of excluded neighbour atoms. + """ + + neighbours, neighbour_mask = self.calcaulate(coordinate, pbc_box) + success = True + success = F.depend(success, F.assign(self.neighbours, neighbours)) + if self.neighbour_mask is not None: + success = F.depend(success, F.assign(self.neighbour_mask, neighbour_mask)) + return success diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6237a1abe7ca9bafefa8e1d02cb5d1e5d21340d7 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Pipeline""" +from .pipeline import PipeLine diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f11cfdc6242670f128804b599bb0cdd64d299ce --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Cell"""
+from .basic import Attention, GlobalAttention
+from .msa import MSARowAttentionWithPairBias, MSAColumnAttention, MSAColumnGlobalAttention
+from .triangle import TriangleAttention, TriangleMultiplication, OuterProductMean
+from .equivariant import InvariantPointAttention
+from .transition import Transition
+
+__all__ = ['Attention', 'GlobalAttention', 'MSARowAttentionWithPairBias',
+           'MSAColumnAttention', 'MSAColumnGlobalAttention',
+           'TriangleAttention', 'TriangleMultiplication', 'OuterProductMean',
+           'InvariantPointAttention', 'Transition']
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/amp.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..06bdd40108c278e2f026d611312fd086c1325112
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/amp.py
@@ -0,0 +1,49 @@
+# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""amp"""
+
+import mindspore.common.dtype as mstype
+from mindspore import nn
+from mindspore.ops import functional as F
+
+
+class OutputTo16(nn.Cell):
+    """Wrapper cell for AMP: cast the output of the wrapped cell back to float16."""
+
+    def __init__(self, op):
+        super(OutputTo16, self).__init__(auto_prefix=False)
+        self._op = op
+
+    def construct(self, *x):
+        return F.cast(self._op(*x), mstype.float16)
+
+
+def amp_convert(network, white_list=None):
+    """Cast the network to float16, keeping cells of the types in `white_list` at float32."""
+    network.to_float(mstype.float16)
+    if white_list is not None:
+        cells = network.name_cells()
+        change = False
+        for name in cells:
+            subcell = cells[name]
+            if subcell == network:
+                continue
+            elif isinstance(subcell, white_list):
+                # Run white-listed cells in float32 and cast their outputs back to float16.
+                network._cells[name] = OutputTo16(subcell.to_float(mstype.float32))
+                change = True
+            else:
+                amp_convert(subcell, white_list)
+        if isinstance(network, nn.SequentialCell) and change:
+            network.cell_list = list(network.cells())
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/basic.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b496452ab78105a570d904295cb202095790a9f
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/basic.py
@@ -0,0 +1,428 @@
+# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""basic"""
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.ops import operations as P
+from .initializer import glorot_uniform
+
+
+class Attention(nn.Cell):
+    r"""
+    This is an implementation of multihead attention in the paper `Attention is all you need
+    <https://arxiv.org/abs/1706.03762>`_. Given a query vector with source length, and key and
+    value vectors with target length, the attention is computed as follows:
+
+    .. math::
+
+        Attention(query, key, value) = Concat(head_1, \dots, head_h)W^O
+
+    where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. A bias is used by default.
+
+    If the query, key and value tensors are identical, this becomes a variant of
+    self-attention.
+
+    Args:
+        num_head(int): The number of the heads.
+        hidden_size(int): The hidden size of the input.
+        gating(bool): Indicator of whether the attention is gated.
+        q_data_dim(int): The last dimension length of the query tensor.
+        m_data_dim(int): The last dimension length of the key and value tensor.
+        output_dim(int): The last dimension length of the output tensor.
+        batch_size(int): The batch size of parameters in attention, used in while
+            control flow. Default: None.
+
+    Inputs:
+        - **q_data** (Tensor) - The query tensor with shape (batch_size,
+          query_seq_length, q_data_dim) with query_seq_length the query sequence length.
+        - **m_data** (Tensor) - The key/value tensor with shape (batch_size,
+          value_seq_length, m_data_dim) with value_seq_length the value sequence length.
+        - **attention_mask** (Tensor) - The mask for the attention matrix with shape
+          (batch_size, num_head, query_seq_length, value_seq_length).
+        - **index** (Tensor) - The index of while loop, only used in case of while
+          control flow. Default: None.
+        - **nonbatched_bias** (Tensor) - Non-batched bias for the attention matrix with
+          shape (num_heads, query_seq_length, value_seq_length). Default: None.
+
+    Outputs:
+        Tensor, output tensor of the Attention layer with shape (batch_size,
+        query_seq_length, hidden_size).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import Attention
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = Attention(num_head=4, hidden_size=64, gating=True, q_data_dim=64,
+        ...                   m_data_dim=64, output_dim=64)
+        >>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32)
+        >>> m_data = Tensor(np.ones((32, 256, 64)), mstype.float32)
+        >>> attention_mask = Tensor(np.ones((32, 4, 128, 256)), mstype.float32)
+        >>> attn_out = model(q_data, m_data, attention_mask)
+        >>> print(attn_out.shape)
+        (32, 128, 64)
+    """
+
+    def __init__(self, num_head, hidden_size, gating, q_data_dim, m_data_dim, output_dim,
+                 batch_size=None):
+        super(Attention, self).__init__()
+        self.q_data_dim = q_data_dim
+        self.m_data_dim = m_data_dim
+        self.output_dim = output_dim
+        self.num_head = num_head
+        self.gating = gating
+        self.hidden_size = hidden_size
+        self.dim_per_head = self.hidden_size // self.num_head
+        self.batch_size = batch_size
+        self.matmul = P.MatMul(transpose_b=True)
+        self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
+        self.softmax = nn.Softmax()
+        self.sigmoid = nn.Sigmoid()
+        self._init_parameter()
+
+    def construct(self, q_data, m_data, attention_mask, index=None, nonbatched_bias=None):
+        '''construct'''
+        if self.batch_size:
+            linear_q_weight = P.Gather()(self.linear_q_weights, index, 0)
+            linear_k_weight = P.Gather()(self.linear_k_weights, index, 0)
+            linear_v_weight = P.Gather()(self.linear_v_weights, index, 0)
+            linear_output_weight = P.Gather()(self.linear_output_weights, index, 0)
+            o_bias = P.Gather()(self.o_biases, index, 0)
+            linear_gating_weight = 0
+            gating_bias = 0
+            if self.gating:
+                linear_gating_weight = P.Gather()(self.linear_gating_weights, index, 0)
+                gating_bias = P.Gather()(self.gating_biases, index, 0)
+        else:
+            linear_q_weight = self.linear_q_weights
+            linear_k_weight = self.linear_k_weights
+            linear_v_weight = self.linear_v_weights
+            linear_output_weight = self.linear_output_weights
+            o_bias = self.o_biases
+            linear_gating_weight = 0
+            gating_bias = 0
+            if self.gating:
+                linear_gating_weight = self.linear_gating_weights
+                gating_bias = self.gating_biases
+
+        dim_b, dim_q, dim_a = q_data.shape
+        _, dim_k, dim_c = m_data.shape
+        dim_h = self.num_head
+
+        q_data = P.Reshape()(q_data, (-1, dim_a))
+        m_data = P.Reshape()(m_data, (-1, dim_c))
+
+        q = self.matmul(q_data, linear_q_weight) * self.dim_per_head ** (-0.5)
+        k = self.matmul(m_data, linear_k_weight)
+        v = self.matmul(m_data, linear_v_weight)
+
+        q = P.Reshape()(q, (dim_b, dim_q, dim_h, -1))
+        k = P.Reshape()(k, (dim_b, dim_k, dim_h, -1))
+        v = P.Reshape()(v, (dim_b, dim_k, dim_h, -1))
+
+        tmp_q = P.Transpose()(q, (0, 2, 1, 3))
+        tmp_k = P.Transpose()(k, (0, 2, 1, 3))
+        logits = P.Add()(self.batch_matmul_trans_b(tmp_q, tmp_k), attention_mask)
+
+        if nonbatched_bias is not None:
+            bias = P.ExpandDims()(nonbatched_bias, 0)
+            logits = P.Add()(logits, bias)
+        weights = self.softmax(logits)
+        tmp_v = P.Transpose()(v, (0, 2, 3, 1))
+
+        weighted_avg = P.Transpose()(self.batch_matmul_trans_b(weights, tmp_v), (0, 2, 1, 3))
+
+        if self.gating:
+            gating_bias = P.ExpandDims()(P.ExpandDims()(gating_bias, 0), 0)
+            gate_values = P.Add()(P.Reshape()(self.matmul(q_data, linear_gating_weight),
+                                              (dim_b, dim_q, dim_h, -1)),
+                                  gating_bias)
+            gate_values = self.sigmoid(gate_values)
+            weighted_avg = P.Reshape()(weighted_avg * gate_values, (dim_b * dim_q, -1))
+
+        weighted_avg = P.Reshape()(weighted_avg, (dim_b * dim_q, -1))
+        output = P.Add()(P.Reshape()(self.matmul(weighted_avg, linear_output_weight),
+                                     (dim_b, dim_q, -1)),
+                         P.ExpandDims()(o_bias, 0))
+        return output
+
+    def _init_parameter(self):
+        '''init parameter'''
+        if self.batch_size:
+            self.linear_q_weights = Parameter(Tensor(np.zeros([self.batch_size,
+                                                               self.num_head * self.dim_per_head,
+                                                               self.q_data_dim]), mstype.float32))
+            self.linear_k_weights = Parameter(Tensor(np.zeros([self.batch_size,
+                                                               self.num_head * self.dim_per_head,
+                                                               self.m_data_dim]), mstype.float32))
+            self.linear_v_weights = Parameter(Tensor(np.zeros([self.batch_size,
+                                                               self.num_head * self.dim_per_head,
+                                                               self.m_data_dim]), mstype.float32))
+            self.linear_output_weights = Parameter(Tensor(np.zeros([self.batch_size,
+                                                                    self.output_dim,
+                                                                    self.num_head * \
+                                                                    self.dim_per_head]),
+                                                          mstype.float32))
+            self.o_biases = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim]),
+                                             mstype.float32))
+            if self.gating:
+                self.linear_gating_weights = Parameter(Tensor(np.zeros([self.batch_size,
+                                                                        self.num_head * \
+                                                                        self.dim_per_head,
+                                                                        self.q_data_dim]),
+                                                              mstype.float32))
+                self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size,
+                                                                self.num_head,
+                                                                self.dim_per_head)),
+                                                      mstype.float32), name="gating_b")
+        else:
+            self.linear_q_weights = Parameter(Tensor(
+                glorot_uniform(self.num_head * self.q_data_dim, self.dim_per_head * self.q_data_dim,
+                               [self.num_head * self.dim_per_head, self.q_data_dim]),
+                mstype.float32))
+            self.linear_k_weights = Parameter(Tensor(
+                glorot_uniform(self.num_head * self.m_data_dim, self.dim_per_head * self.m_data_dim,
+                               [self.num_head * self.dim_per_head, self.m_data_dim]),
+                mstype.float32))
+            self.linear_v_weights = Parameter(Tensor(
+                glorot_uniform(self.num_head * self.m_data_dim, self.dim_per_head * self.m_data_dim,
+                               [self.num_head * self.dim_per_head, self.m_data_dim]),
+                mstype.float32))
+            self.linear_output_weights = Parameter(
+                Tensor(np.zeros([self.output_dim, self.num_head * self.dim_per_head]),
+                       mstype.float32))
+            self.o_biases = Parameter(Tensor(np.zeros([self.output_dim]), mstype.float32))
+            if self.gating:
+                self.linear_gating_weights = Parameter(
+                    Tensor(np.zeros([self.num_head * self.dim_per_head, self.q_data_dim]),
+                           mstype.float32))
+                self.gating_biases = Parameter(Tensor(np.ones((self.num_head, self.dim_per_head)),
+                                                      mstype.float32),
+                                               name="gating_b")
+
+
+class GlobalAttention(nn.Cell):
+    r"""
+    This is an implementation of global gated self-attention from the paper `Highly accurate
+    protein structure prediction with AlphaFold
+    <https://www.nature.com/articles/s41586-021-03819-2>`_. For this attention, the
+    shape of the query tensor, key tensor and the value tensor should be the same.
+
+    Args:
+        num_head(int): The number of the heads.
+        gating(bool): Indicator of whether the attention is gated.
+        input_dim(int): The last dimension length of the input tensor.
+        output_dim(int): The last dimension length of the output tensor.
+        batch_size(int): The batch size of parameters in attention, used in while control
+            flow. Default: None.
+
+    Inputs:
+        - **q_data** (Tensor) - The query tensor with shape (batch_size, seq_length,
+          input_dim) with seq_length the sequence length.
+        - **m_data** (Tensor) - The key/value tensor with shape (batch_size, seq_length,
+          input_dim).
+        - **q_mask** (Tensor) - A binary mask for q_data of shape (batch_size,
+          seq_length, 1).
+        - **bias** (Tensor) - Bias for the attention matrix. Default: None.
+        - **index** (Tensor) - The index of while loop, only used in case of while control
+          flow. Default: None.
+
+    Outputs:
+        Tensor, output tensor of the GlobalAttention layer with shape (batch_size, seq_length, output_dim).
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import GlobalAttention + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> model = GlobalAttention(num_head=4, input_dim=64, gating=True, output_dim=256) + >>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32) + >>> m_data = Tensor(np.ones((32, 128, 64)), mstype.float32) + >>> q_mask = Tensor(np.ones((32, 128, 1)), mstype.float32) + >>> attn_out= model(q_data, m_data, q_mask) + >>> print(attn_out.shape) + (32, 128, 256) + """ + + def __init__(self, num_head, gating, input_dim, output_dim, batch_size=None): + super(GlobalAttention, self).__init__() + + self.input_dim = input_dim + self.num_head = num_head + self.dim_per_head = self.input_dim // self.num_head + self.output_dim = output_dim + self.matmul_trans_b = P.MatMul(transpose_b=True) + self.batch_matmul = P.BatchMatMul() + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.matmul = P.MatMul() + self.softmax = nn.Softmax() + self.sigmoid = nn.Sigmoid() + self.gating = gating + self.batch_size = batch_size + self._init_parameter() + + def construct(self, q_data, m_data, q_mask, index=None): + '''construct''' + if self.batch_size: + q_weights = P.Gather()(self.linear_q_weights, index, 0) + k_weights = P.Gather()(self.linear_k_weights, index, 0) + v_weights = P.Gather()(self.linear_v_weights, index, 0) + output_weights = P.Gather()(self.linear_output_weights, index, 0) + output_bias = P.Gather()(self.o_biases, index, 0) + gating_weights = 0 + gating_bias = 0 + if self.gating: + gating_weights = P.Gather()(self.linear_gating_weights, index, 0) + gating_bias = P.Gather()(self.gating_biases, index, 0) + else: + q_weights = self.linear_q_weights + k_weights = self.linear_k_weights + v_weights = self.linear_v_weights + output_weights = self.linear_output_weights + output_bias = self.o_biases + gating_weights = 0 + gating_bias = 0 + if self.gating: + gating_weights = self.linear_gating_weights + gating_bias = self.gating_biases + + b, _, _ = m_data.shape + + v_weights = P.BroadcastTo((b, + self.dim_per_head * self.num_head, + self.dim_per_head))(v_weights) + v = self.batch_matmul(m_data, v_weights) + + mask_shape = q_mask.shape + value_shape = q_data.shape + broadcast_factor = 1. 
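+        # Global attention uses a single pooled query per head: qa/qb below form a
+        # mask-weighted average of q_data over the sequence axis, and
+        # broadcast_factor corrects the normalisation when the mask holds a single
+        # broadcast position (mask_size == 1).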
+ value_size = value_shape[1] + mask_size = mask_shape[1] + if mask_size == 1: + broadcast_factor = broadcast_factor * value_size + qa = P.ReduceSum()(q_mask * q_data, 1) + qb = P.ReduceSum()(q_mask, 1) * broadcast_factor + 1e-10 + q_avg = P.RealDiv()(qa, qb) + + q = P.Reshape()(self.matmul(q_avg, q_weights), + (-1, self.num_head, self.dim_per_head)) * (self.dim_per_head ** (-0.5)) + + k_weights = P.BroadcastTo((b, + self.dim_per_head * self.num_head, + self.dim_per_head))(k_weights) + k = self.batch_matmul(m_data, k_weights) + + attention_mask = 1e9 * (P.Transpose()(q_mask, (0, 2, 1)) - 1.0) + logits = P.Add()(self.batch_matmul_trans_b(q, k), attention_mask) + + weights = self.softmax(logits) + weighted_avg = self.batch_matmul(weights, v) + + if self.gating: + q_data_shape = P.Shape()(q_data) + if len(q_data_shape) != 2: + q_data = P.Reshape()(q_data, (-1, q_data_shape[-1])) + out_shape = q_data_shape[:-1] + (-1,) + gate_values = P.Reshape()(self.matmul_trans_b(q_data, gating_weights) + gating_bias, + out_shape) + + gate_values = P.Reshape()(self.sigmoid(gate_values), + (b, -1, self.num_head, self.dim_per_head)) + weighted_avg = P.Reshape()(P.ExpandDims()(weighted_avg, 1) * gate_values, + (-1, self.num_head * self.dim_per_head)) + weighted_avg_shape = P.Shape()(weighted_avg) + if len(weighted_avg_shape) != 2: + weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1])) + output = P.Reshape()(P.Add()(self.matmul_trans_b(weighted_avg, + output_weights), output_bias), + (b, -1, self.output_dim)) + else: + weighted_avg = P.Reshape()(weighted_avg, (-1, self.num_head * self.dim_per_head)) + weighted_avg_shape = P.Shape()(weighted_avg) + if len(weighted_avg_shape) != 2: + weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1])) + out_shape = weighted_avg_shape[:-1] + (-1,) + output = P.Reshape()(P.Add()(self.matmul_trans_b(weighted_avg, output_weights), + output_bias), out_shape) + output = P.ExpandDims()(output, -1) + return output + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.linear_q_weights = Parameter( + Tensor(np.zeros((self.batch_size, + self.input_dim, + self.num_head, + self.dim_per_head)), + mstype.float32)) + self.linear_k_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim, self.dim_per_head)), + mstype.float32)) + self.linear_v_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim, self.dim_per_head)), + mstype.float32)) + self.linear_output_weights = Parameter( + Tensor(np.zeros((self.batch_size, + self.output_dim, + self.num_head * self.dim_per_head)), + mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.output_dim)), + mstype.float32)) + if self.gating: + self.linear_gating_weights = Parameter( + Tensor(np.zeros((self.batch_size, + self.num_head * self.dim_per_head, + self.input_dim)), + mstype.float32)) + self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.input_dim)), + mstype.float32)) + else: + self.linear_q_weights = Parameter(Tensor( + glorot_uniform(self.num_head * self.input_dim, + self.dim_per_head * self.input_dim, + (self.input_dim, self.num_head*self.dim_per_head)), + mstype.float32)) + self.linear_k_weights = Parameter( + Tensor(glorot_uniform(self.input_dim, + self.dim_per_head, + (1, self.input_dim, self.dim_per_head)), + mstype.float32)) + self.linear_v_weights = Parameter( + Tensor(glorot_uniform(self.input_dim, + self.dim_per_head, + (1, self.input_dim, self.dim_per_head)), + mstype.float32)) + 
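+            # Output projection weights start at zero and gating biases at one, so
+            # the gates begin mostly open while the layer's initial output is zero.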
+            self.linear_output_weights = Parameter(
+                Tensor(np.zeros((self.output_dim, self.num_head * self.dim_per_head)),
+                       mstype.float32))
+            self.o_biases = Parameter(Tensor(np.zeros((self.output_dim)),
+                                             mstype.float32))
+            if self.gating:
+                self.linear_gating_weights = Parameter(
+                    Tensor(np.zeros((self.num_head * self.dim_per_head, self.input_dim)),
+                           mstype.float32))
+                self.gating_biases = Parameter(Tensor(np.ones((self.input_dim)), mstype.float32))
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/equivariant.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/equivariant.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a7d59299199a8520dffa62120f58bafa5f7505f
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/equivariant.py
@@ -0,0 +1,244 @@
+# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Equivariant"""
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+import mindspore.numpy as mnp
+import mindspore.ops as ops
+from mindspore import Parameter
+from mindspore.common.tensor import Tensor
+from ...common.geometry import apply_to_point, invert_point
+from .initializer import lecun_init
+
+
+class InvariantPointAttention(nn.Cell):
+    r"""
+    Invariant point attention module.
+    This module is used to update the sequence representation (the first input, inputs_1d)
+    by adding location information to it.
+
+    The attention consists of three parts: q, k, v obtained from the sequence representation;
+    q', k', v' obtained from the interaction between the sequence representation and the rigid
+    body group; and the bias b, obtained from the pair representation (the second input,
+    inputs_2d).
+
+    .. math::
+
+        a_{ij} = \mathrm{Softmax}\left(w_l\left(c_1 q_i^T k_j + b_{ij} -
+            c_2 \sum \left\| T_i \circ q'_i - T_j \circ k'_j \right\|^2\right)\right)
+
+    where i and j represent the ith and jth amino acids in the sequence, respectively,
+    and T is the rotation and translation in the input.
+
+    `Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention"
+    <https://www.nature.com/articles/s41586-021-03819-2>`_.
+
+    Args:
+        num_head (int): The number of the heads.
+        num_scalar_qk (int): The number of the scalar query/key.
+        num_scalar_v (int): The number of the scalar value.
+        num_point_v (int): The number of the point value.
+        num_point_qk (int): The number of the point query/key.
+        num_channel (int): The number of the channel.
+        pair_dim (int): The last dimension length of pair.
+
+    Inputs:
+        - **inputs_1d** (Tensor) - The first row of the msa representation which is the output of the
+          evoformer module, also called the sequence representation, shape :math:`[N_{res}, num\_channel]`.
+        - **inputs_2d** (Tensor) - The pair representation which is the output of the evoformer module,
+          shape :math:`[N_{res}, N_{res}, pair\_dim]`.
+        - **mask** (Tensor) - A mask that determines which elements of inputs_1d are involved in the
+          attention calculation, shape :math:`[N_{res}, 1]`.
+        - **rotation** (tuple) - A rotation term in a rigid body group T(r,t),
+          a tuple of length 9, where each element has shape :math:`[N_{res}]`.
+        - **translation** (tuple) - A translation term in a rigid body group T(r,t),
+          a tuple of length 3, where each element has shape :math:`[N_{res}]`.
+
+    Outputs:
+        Tensor, the update of inputs_1d, shape :math:`[N_{res}, channel]`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import InvariantPointAttention
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> import mindspore.context as context
+        >>> context.set_context(mode=context.GRAPH_MODE)
+        >>> model = InvariantPointAttention(num_head=12, num_scalar_qk=16, num_scalar_v=16,
+        ...                                 num_point_v=8, num_point_qk=4,
+        ...                                 num_channel=384, pair_dim=128)
+        >>> inputs_1d = Tensor(np.ones((256, 384)), mstype.float32)
+        >>> inputs_2d = Tensor(np.ones((256, 256, 128)), mstype.float32)
+        >>> mask = Tensor(np.ones((256, 1)), mstype.float32)
+        >>> rotation = tuple([Tensor(np.ones(256), mstype.float16) for _ in range(9)])
+        >>> translation = tuple([Tensor(np.ones(256), mstype.float16) for _ in range(3)])
+        >>> attn_out = model(inputs_1d, inputs_2d, mask, rotation, translation)
+        >>> print(attn_out.shape)
+        (256, 384)
+    """
+
+    def __init__(self, num_head, num_scalar_qk, num_scalar_v, num_point_v, num_point_qk, num_channel, pair_dim):
+        super(InvariantPointAttention, self).__init__()
+
+        self._dist_epsilon = 1e-8
+        self.num_head = num_head
+        self.num_scalar_qk = num_scalar_qk
+        self.num_scalar_v = num_scalar_v
+        self.num_point_v = num_point_v
+        self.num_point_qk = num_point_qk
+        self.num_channel = num_channel
+        self.projection_num = self.num_head * self.num_scalar_v + self.num_head * self.num_point_v * 4 + \
+                              self.num_head * pair_dim
+        self.q_scalar = nn.Dense(self.num_channel, self.num_head * self.num_scalar_qk,
+                                 weight_init=lecun_init(self.num_channel))
+        self.kv_scalar = nn.Dense(self.num_channel, self.num_head * (self.num_scalar_qk + self.num_scalar_v),
+                                  weight_init=lecun_init(self.num_channel))
+        self.q_point_local = nn.Dense(self.num_channel, self.num_head * 3 * self.num_point_qk,
+                                      weight_init=lecun_init(self.num_channel))
+        self.kv_point_local = nn.Dense(self.num_channel, self.num_head * 3 * (self.num_point_qk + self.num_point_v),
+                                       weight_init=lecun_init(self.num_channel))
+        self.soft_max = nn.Softmax()
+        self.soft_plus = ops.Softplus()
+        self.trainable_point_weights = Parameter(Tensor(np.ones((12,)), mstype.float32),
+                                                 name="trainable_point_weights")
+        self.attention_2d = nn.Dense(pair_dim, self.num_head, weight_init=lecun_init(pair_dim))
+        self.output_projection = nn.Dense(self.projection_num, self.num_channel, weight_init='zeros')
+        self.scalar_weights = Tensor(np.sqrt(1.0 / (3 * 16)).astype(np.float32))
+        self.point_weights = Tensor(np.sqrt(1.0 / (3 * 18)).astype(np.float32))
+        self.attention_2d_weights = Tensor(np.sqrt(1.0 / 3).astype(np.float32))
+
+    def construct(self, inputs_1d, inputs_2d, mask, rotation, translation):
+        '''construct'''
+        num_residues, _ = inputs_1d.shape
+
+        # Improve readability by removing a large number of 'self's.
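+        # Each query/key/value carries scalar channels and 3D point channels; the
+        # points are mapped into the global frame with (rotation, translation), so
+        # the point distances below, and hence the attention weights, do not
+        # depend on the choice of per-residue local frames.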
+        num_head = self.num_head
+        num_scalar_qk = self.num_scalar_qk
+        num_point_qk = self.num_point_qk
+        num_scalar_v = self.num_scalar_v
+        num_point_v = self.num_point_v
+
+        # Construct scalar queries of shape [num_residues, num_head, num_scalar_qk].
+        q_scalar = self.q_scalar(inputs_1d)
+        q_scalar = mnp.reshape(q_scalar, [num_residues, num_head, num_scalar_qk])
+
+        # Construct scalar keys/values of shape [num_residues, num_head, num_scalar_v + num_scalar_qk].
+        kv_scalar = self.kv_scalar(inputs_1d)
+        kv_scalar = mnp.reshape(kv_scalar, [num_residues, num_head, num_scalar_v + num_scalar_qk])
+        k_scalar, v_scalar = mnp.split(kv_scalar, [num_scalar_qk], axis=-1)
+
+        # Construct query points of shape [num_residues, num_head, num_point_qk] per coordinate.
+        # First construct query points in the local frame.
+        q_point_local = self.q_point_local(inputs_1d)
+
+        q_point_local = mnp.split(q_point_local, 3, axis=-1)
+        q_point_local = (ops.Squeeze()(q_point_local[0]), ops.Squeeze()(q_point_local[1]),
+                         ops.Squeeze()(q_point_local[2]))
+        # Project query points into the global frame.
+        q_point_global = apply_to_point(rotation, translation, q_point_local, 1)
+
+        # Reshape query points for later use.
+        q_point0 = mnp.reshape(q_point_global[0], (num_residues, num_head, num_point_qk))
+        q_point1 = mnp.reshape(q_point_global[1], (num_residues, num_head, num_point_qk))
+        q_point2 = mnp.reshape(q_point_global[2], (num_residues, num_head, num_point_qk))
+
+        # Construct key and value points.
+        # Key points have shape [num_residues, num_head, num_point_qk].
+        # Value points have shape [num_residues, num_head, num_point_v].
+
+        # Construct key and value points in the local frame.
+        kv_point_local = self.kv_point_local(inputs_1d)
+
+        kv_point_local = mnp.split(kv_point_local, 3, axis=-1)
+        kv_point_local = (ops.Squeeze()(kv_point_local[0]), ops.Squeeze()(kv_point_local[1]),
+                          ops.Squeeze()(kv_point_local[2]))
+        # Project key and value points into the global frame.
+        kv_point_global = apply_to_point(rotation, translation, kv_point_local, 1)
+
+        kv_point_global0 = mnp.reshape(kv_point_global[0], (num_residues, num_head, (num_point_qk + num_point_v)))
+        kv_point_global1 = mnp.reshape(kv_point_global[1], (num_residues, num_head, (num_point_qk + num_point_v)))
+        kv_point_global2 = mnp.reshape(kv_point_global[2], (num_residues, num_head, (num_point_qk + num_point_v)))
+
+        # Split key and value points.
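+        # Along the last axis the kv points pack num_point_qk key channels
+        # followed by num_point_v value channels; a single split separates them.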
+ k_point0, v_point0 = mnp.split(kv_point_global0, [num_point_qk], axis=-1) + k_point1, v_point1 = mnp.split(kv_point_global1, [num_point_qk], axis=-1) + k_point2, v_point2 = mnp.split(kv_point_global2, [num_point_qk], axis=-1) + + trainable_point_weights = self.soft_plus(self.trainable_point_weights) + point_weights = self.point_weights * mnp.expand_dims(trainable_point_weights, axis=1) + + v_point = [mnp.swapaxes(v_point0, -2, -3), mnp.swapaxes(v_point1, -2, -3), mnp.swapaxes(v_point2, -2, -3)] + q_point = [mnp.swapaxes(q_point0, -2, -3), mnp.swapaxes(q_point1, -2, -3), mnp.swapaxes(q_point2, -2, -3)] + k_point = [mnp.swapaxes(k_point0, -2, -3), mnp.swapaxes(k_point1, -2, -3), mnp.swapaxes(k_point2, -2, -3)] + + dist2 = mnp.square(ops.expand_dims(q_point[0], 2) - ops.expand_dims(k_point[0], 1)) + \ + mnp.square(ops.expand_dims(q_point[1], 2) - ops.expand_dims(k_point[1], 1)) + \ + mnp.square(ops.expand_dims(q_point[2], 2) - ops.expand_dims(k_point[2], 1)) + + attn_qk_point = -0.5 * mnp.sum(ops.expand_dims(ops.expand_dims(point_weights, 1), 1) * dist2, axis=-1) + + v = mnp.swapaxes(v_scalar, -2, -3) + q = mnp.swapaxes(self.scalar_weights * q_scalar, -2, -3) + k = mnp.swapaxes(k_scalar, -2, -3) + attn_qk_scalar = ops.matmul(q, mnp.swapaxes(k, -2, -1)) + attn_logits = attn_qk_scalar + attn_qk_point + + attention_2d = self.attention_2d(inputs_2d) + attention_2d = mnp.transpose(attention_2d, [2, 0, 1]) + attention_2d = self.attention_2d_weights * attention_2d + + attn_logits += attention_2d + + mask_2d = mask * mnp.swapaxes(mask, -1, -2) + attn_logits -= 50 * (1. - mask_2d) + + attn = self.soft_max(attn_logits) + + result_scalar = ops.matmul(attn, v) + + result_point_global = [mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[0][:, None, :, :], axis=-2), -2, -3), + mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[1][:, None, :, :], axis=-2), -2, -3), + mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[2][:, None, :, :], axis=-2), -2, -3) + ] + + result_point_global = [mnp.reshape(result_point_global[0], [num_residues, num_head * num_point_v]), + mnp.reshape(result_point_global[1], [num_residues, num_head * num_point_v]), + mnp.reshape(result_point_global[2], [num_residues, num_head * num_point_v])] + result_scalar = mnp.swapaxes(result_scalar, -2, -3) + + result_scalar = mnp.reshape(result_scalar, [num_residues, num_head * num_scalar_v]) + + result_point_local = invert_point(result_point_global, rotation, translation, 1) + + output_feature1 = result_scalar + output_feature20 = result_point_local[0] + output_feature21 = result_point_local[1] + output_feature22 = result_point_local[2] + + output_feature3 = mnp.sqrt(self._dist_epsilon + + mnp.square(result_point_local[0]) + + mnp.square(result_point_local[1]) + + mnp.square(result_point_local[2])) + + result_attention_over_2d = ops.matmul(mnp.swapaxes(attn, 0, 1), inputs_2d) + num_out = num_head * result_attention_over_2d.shape[-1] + output_feature4 = mnp.reshape(result_attention_over_2d, [num_residues, num_out]) + + final_act = mnp.concatenate([output_feature1, output_feature20, output_feature21, + output_feature22, output_feature3, output_feature4], axis=-1) + final_result = self.output_projection(final_act) + return final_result diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/initializer.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..718aa5f11c6c34c5f3b8d425ede68f98ad0afd18 --- /dev/null +++ 
b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/initializer.py @@ -0,0 +1,35 @@ +# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""initializer""" + +import numpy as np +from mindspore.common.initializer import TruncatedNormal + +TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978, dtype=np.float32) + + +def lecun_init(fan_in, initializer_name='linear'): + """lecun init""" + scale = 1.0 + if initializer_name == 'relu': + scale *= 2 + weight_init = TruncatedNormal(sigma=np.sqrt(scale / fan_in) / TRUNCATED_NORMAL_STDDEV_FACTOR) + return weight_init + + +def glorot_uniform(fan_in, fan_out, weight_shape): + """glorot uniform""" + limit = np.sqrt(6 / (fan_in + fan_out)) + return np.random.uniform(-limit, limit, size=weight_shape) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/mask.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..56becdd0a85142d98f4b1aabbf59dc7de32308e7 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/mask.py @@ -0,0 +1,44 @@ +# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Mask"""
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+import mindspore.nn as nn
+
+
+class MaskedLayerNorm(nn.Cell):
+    '''masked_layer_norm'''
+
+    def __init__(self):
+        super(MaskedLayerNorm, self).__init__()
+        self.norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+
+    def construct(self, act, gamma, beta, mask=None):
+        '''construct'''
+        ones = P.Ones()(act.shape[:-1] + (1,), act.dtype)
+        if mask is not None:
+            mask = F.expand_dims(mask, -1)
+            mask = mask * ones
+        else:
+            mask = ones
+
+        # Zero out masked positions both before and after the layer norm.
+        act = act * mask
+        act, _, _ = self.norm(act, gamma, beta)
+        act = act * mask
+        return act
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/msa.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/msa.py
new file mode 100644
index 0000000000000000000000000000000000000000..841003c252916c71a153370cc9f59cad5bdb2deb
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/msa.py
@@ -0,0 +1,357 @@
+# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MSA"""
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.ops import operations as P
+from .basic import Attention, GlobalAttention
+from .mask import MaskedLayerNorm
+from ...common.utils import _memory_reduce
+
+
+class MSARowAttentionWithPairBias(nn.Cell):
+    r"""
+    MSA row attention. The pair activations provide a per-head bias for the row-wise
+    attention matrix, so that the state of the MSA is updated with pair information.
+
+    Reference:
+        `Jumper et al. (2021) Suppl. Alg. 7 'MSARowAttentionWithPairBias'
+        <https://www.nature.com/articles/s41586-021-03819-2>`_.
+
+    Args:
+        num_head (int): The number of the attention head.
+        key_dim (int): The dimension of the attention hidden layer.
+        gating (bool): Indicator of whether the attention is gated.
+        msa_act_dim (int): The dimension of the msa_act.
+        pair_act_dim (int): The dimension of the pair_act.
+        batch_size (int): The batch size of parameters in MSA row attention, used in while control flow.
+            Default: None.
+        slice_num (int): The number of slices to be made to reduce memory. Default: 0.
+
+    Inputs:
+        - **msa_act** (Tensor) - Tensor of msa_act with shape :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .
+        - **msa_mask** (Tensor) - The mask for MSA row attention matrix with shape :math:`(N_{seqs}, N_{res})` .
+        - **pair_act** (Tensor) - Tensor of pair_act with shape :math:`(N_{res}, N_{res}, pair\_act\_dim)` .
+          Data type is float.
+        - **index** (Tensor) - The index of while loop, only used in case of while control flow. Default: "None".
+ - **norm_msa_mask** (Tensor) - The mask of msa_act when to do layernorm with shape :math:`(N_{seqs}, N_{res})`, + Default: "None". + - **norm_pair_mask** (Tensor) - The mask of pair_act when to do layernorm with shape :math:`(N_{res}, N_{res})`, + Default: "None". + - **res_idx** (Tensor) - The residue index used to perform ROPE with shape :math:`(N_{res})`, Default: "None". + + Outputs: + Tensor, the float tensor of the msa_act of the layer with shape :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` . + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import MSARowAttentionWithPairBias + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> model = MSARowAttentionWithPairBias(num_head=4, key_dim=4, gating=True, + ... msa_act_dim=64, pair_act_dim=128, + ... batch_size=None) + >>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32) + >>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16) + >>> pair_act = Tensor(np.ones((256, 256, 128)), mstype.float32) + >>> index = None + >>> msa_out = model(msa_act, msa_mask, pair_act, index) + >>> print(msa_out.shape) + (4, 256, 64) + """ + + def __init__(self, num_head, key_dim, gating, msa_act_dim, pair_act_dim, batch_size=None, slice_num=0): + super(MSARowAttentionWithPairBias, self).__init__() + self.num_head = num_head + self.batch_size = batch_size + self.matmul = P.MatMul(transpose_b=True) + self.attn_mod = Attention(num_head, key_dim, gating, msa_act_dim, msa_act_dim, msa_act_dim, batch_size) + self.msa_act_dim = msa_act_dim + self.pair_act_dim = pair_act_dim + self.batch_size = batch_size + self.slice_num = slice_num + self.idx = Tensor(0, mstype.int32) + self.masked_layer_norm = MaskedLayerNorm() + self._init_parameter() + + def construct(self, msa_act, msa_mask, pair_act, index=None, norm_msa_mask=None, norm_pair_mask=None, res_idx=None): + '''construct''' + if self.batch_size: + query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0) + query_norm_beta = P.Gather()(self.query_norm_betas, index, 0) + feat_2d_norm_gamma = P.Gather()(self.feat_2d_norm_gammas, index, 0) + feat_2d_norm_beta = P.Gather()(self.feat_2d_norm_betas, index, 0) + feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0) + else: + query_norm_gamma = self.query_norm_gammas + query_norm_beta = self.query_norm_betas + feat_2d_norm_gamma = self.feat_2d_norm_gammas + feat_2d_norm_beta = self.feat_2d_norm_betas + feat_2d_weight = self.feat_2d_weights + + q, k, _ = pair_act.shape + input_bias = 1e9 * (msa_mask - 1.0) + input_bias = P.ExpandDims()(P.ExpandDims()(input_bias, 1), 2) + + msa_act = self.masked_layer_norm(msa_act, query_norm_gamma, query_norm_beta, mask=norm_msa_mask) + pair_act = self.masked_layer_norm(pair_act, feat_2d_norm_gamma, feat_2d_norm_beta, mask=norm_pair_mask) + pair_act = P.Reshape()(pair_act, (-1, pair_act.shape[-1])) + nonbatched_bias = P.Transpose()(P.Reshape()(self.matmul(pair_act, feat_2d_weight), (q, k, self.num_head)), + (2, 0, 1)) + batched_inputs = (msa_act, input_bias) + if res_idx is not None: + nonbatched_inputs = (nonbatched_bias, res_idx) + else: + nonbatched_inputs = (index, nonbatched_bias) + msa_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num) + return msa_act + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32)) + self.query_norm_betas = 
Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32)) + self.feat_2d_norm_gammas = Parameter( + Tensor(np.zeros([self.batch_size, self.pair_act_dim]), mstype.float32)) + self.feat_2d_norm_betas = Parameter( + Tensor(np.zeros([self.batch_size, self.pair_act_dim]), mstype.float32)) + self.feat_2d_weights = Parameter( + Tensor(np.zeros([self.batch_size, self.num_head, self.pair_act_dim]), mstype.float32)) + else: + self.query_norm_gammas = Parameter(Tensor(np.ones([self.msa_act_dim]), mstype.float32)) + self.query_norm_betas = Parameter(Tensor(np.zeros([self.msa_act_dim]), mstype.float32)) + self.feat_2d_norm_gammas = Parameter(Tensor(np.ones([self.pair_act_dim]), mstype.float32)) + self.feat_2d_norm_betas = Parameter(Tensor(np.zeros([self.pair_act_dim]), mstype.float32)) + self.feat_2d_weights = Parameter( + Tensor(np.random.normal(scale=1 / np.sqrt(self.pair_act_dim), size=[self.num_head, self.pair_act_dim]), + mstype.float32)) + + def _compute(self, msa_act, mask, index, nonbatched_bias): + """ + compute. + + Args: + msa_act (Tensor): Tensor of msa_act. + mask (Tensor): The mask for MSA row attention matrix. + index (Tensor): The index of while loop, only used in case of while control flow. Default: None + nonbatched_bias(Tensor): Tensor of non batched bias matrix. + + Outputs: + - **msa_act** (Tensor)- Tensor, the float tensor of the msa_act of the attention layer. + """ + msa_act = self.attn_mod(msa_act, msa_act, mask, index, nonbatched_bias) + return msa_act + + +class MSAColumnAttention(nn.Cell): + """ + MSA column-wise gated self attention. + The column-wise attention lets the elements that belong to the same target residue exchange information. + + Reference: + `Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention" + `_. + + Args: + num_head (int): The number of the heads. + key_dim (int): The dimension of the input. + gating (bool): Indicator of if the attention is gated. + msa_act_dim (int): The dimension of the msa_act. The intermediate variable after MSA retrieving + in AlphaFold. + batch_size (int): The batch size of parameters in MSAColumnAttention, used in while control flow, + Default: "None". + slice_num (int): The number of slices to be made to reduce memory, Default: 0. + + Inputs: + - **msa_act** (Tensor) - Tensor of msa_act. The intermediate variable after MSA retrieving + in AlphaFold, shape :math:`[N_{seqs}, N_{res}, C_m]` . + - **msa_mask** (Tensor) - The mask for MSAColumnAttention matrix, shape :math:`[N_{seqs}, N_{res}]`. + - **index** (Tensor) - The index of while loop, only used in case of while control flow. Default: "None". + + Outputs: + Tensor, the float tensor of the msa_act of the layer, shape :math:`[N_{seqs}, N_{res}, C_m]`. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import MSAColumnAttention + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> model = MSAColumnAttention(num_head=8, key_dim=256, gating=True, + ... 
msa_act_dim=256, batch_size=1, slice_num=0)
+        >>> msa_act = Tensor(np.ones((512, 256, 256)), mstype.float32)
+        >>> msa_mask = Tensor(np.ones((512, 256)), mstype.float32)
+        >>> index = Tensor(0, mstype.int32)
+        >>> attn_out = model(msa_act, msa_mask, index)
+        >>> print(attn_out.shape)
+        (512, 256, 256)
+    """
+
+    def __init__(self, num_head, key_dim, gating, msa_act_dim, batch_size=None, slice_num=0):
+        super(MSAColumnAttention, self).__init__()
+        self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+        self.attn_mod = Attention(num_head, key_dim, gating, msa_act_dim, msa_act_dim, msa_act_dim, batch_size)
+        self.batch_size = batch_size
+        self.slice_num = slice_num
+        self.msa_act_dim = msa_act_dim
+        self.idx = Tensor(0, mstype.int32)
+        self._init_parameter()
+
+    def construct(self, msa_act, msa_mask, index=None):
+        '''construct'''
+        if self.batch_size:
+            query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
+            query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
+        else:
+            query_norm_gamma = self.query_norm_gammas
+            query_norm_beta = self.query_norm_betas
+        # Swap the sequence and residue axes so that attention runs column-wise.
+        msa_act = P.Transpose()(msa_act, (1, 0, 2))
+        msa_mask = P.Transpose()(msa_mask, (1, 0))
+
+        input_mask = 1e9 * (msa_mask - 1.)
+        input_mask = P.ExpandDims()(P.ExpandDims()(input_mask, 1), 2)
+        msa_act, _, _ = self.query_norm(msa_act, query_norm_gamma, query_norm_beta)
+        batched_inputs = (msa_act, input_mask)
+        nonbatched_inputs = (index,)
+        msa_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num)
+        msa_act = P.Transpose()(msa_act, (1, 0, 2))
+        return msa_act
+
+    def _init_parameter(self):
+        if self.batch_size:
+            self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
+        else:
+            self.query_norm_gammas = Parameter(Tensor(np.ones([self.msa_act_dim]), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros([self.msa_act_dim]), mstype.float32))
+
+    def _compute(self, msa_act, input_mask, index):
+        '''compute'''
+        msa_act = self.attn_mod(msa_act, msa_act, input_mask, index)
+        return msa_act
+
+
+class MSAColumnGlobalAttention(nn.Cell):
+    r"""
+    MSA column global attention. It transposes the MSA representation between the sequence
+    and residue axes, then uses `GlobalAttention` to attend across the input sequences
+    without modelling the relationship between residues within a sequence.
+    Compared with MSAColumnAttention, it uses GlobalAttention to handle longer input sequences.
+
+    Reference:
+        `Jumper et al. (2021) Suppl. Alg. 19 'MSAColumnGlobalAttention'
+        <https://www.nature.com/articles/s41586-021-03819-2>`_.
+
+    Args:
+        num_head (int): The number of the attention heads.
+        gating (bool): Indicator of whether the attention is gated.
+        msa_act_dim (int): The dimension of the msa_act.
+        batch_size (int): The batch size of parameters in MSAColumnGlobalAttention, used
+            in while control flow. Default: None.
+        slice_num (int): The number of slices to be made to reduce memory. Default: 0.
+
+    Inputs:
+        - **msa_act** (Tensor) - Tensor of msa_act with shape :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .
+        - **msa_mask** (Tensor) - The mask for msa_act matrix with shape :math:`(N_{seqs}, N_{res})` .
+        - **index** (Tensor) - The index of while loop, only used in case of while control flow. Default: "None".
+
+    Outputs:
+        Tensor, the float tensor of the msa_act of the layer with shape :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import MSAColumnGlobalAttention + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> model = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64, batch_size=None) + >>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32) + >>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16) + >>> index = None + >>> msa_out = model(msa_act, msa_mask, index) + >>> print(msa_out.shape) + (4, 256, 64) + """ + + def __init__(self, num_head, gating, msa_act_dim, batch_size=None, slice_num=0): + super(MSAColumnGlobalAttention, self).__init__() + self.attn_mod = GlobalAttention(num_head, gating, msa_act_dim, msa_act_dim, batch_size) + self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) + self.batch_size = batch_size + self.slice_num = slice_num + self.msa_act_dim = msa_act_dim + self.idx = Tensor(0, mstype.int32) + self._init_parameter() + + def construct(self, msa_act, msa_mask, index=None): + '''construct''' + if self.batch_size: + query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0) + query_norm_beta = P.Gather()(self.query_norm_betas, index, 0) + msa_act = P.Transpose()(msa_act, (1, 0, 2)) + msa_mask = P.Transpose()(msa_mask, (1, 0)) + else: + query_norm_gamma = self.query_norm_gammas + query_norm_beta = self.query_norm_betas + msa_act = P.Transpose()(msa_act, (1, 0, 2)) + msa_mask = P.Transpose()(msa_mask, (1, 0)) + + input_mask = 1e9 * (msa_mask - 1.) + input_mask = P.ExpandDims()(P.ExpandDims()(input_mask, 1), 2) + + msa_act, _, _ = self.query_norm(msa_act, + query_norm_gamma, + query_norm_beta) + msa_mask = P.ExpandDims()(msa_mask, -1) + batched_inputs = (msa_act, msa_mask) + nonbatched_inputs = (index,) + msa_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num) + msa_act = P.Transpose()(msa_act, (1, 0, 2)) + return msa_act + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32)) + self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32)) + else: + self.query_norm_gammas = Parameter(Tensor(np.ones((self.msa_act_dim)), mstype.float32)) + self.query_norm_betas = Parameter(Tensor(np.zeros((self.msa_act_dim)), mstype.float32)) + + def _compute(self, msa_act, msa_mask, index): + """ + compute. + + Args: + msa_act (Tensor): Tensor of msa_act. + msa_mask (Tensor): The mask for msa_act matrix. + index (Tensor): The index of while loop, only used in case of while + control flow. Default: None + + Outputs: + - **msa_act** (Tensor)- Tensor, the float tensor of the msa_act of the attention layer. 
+ """ + msa_act = self.attn_mod(msa_act, msa_act, msa_mask, index) + return msa_act diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/transition.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/transition.py new file mode 100644 index 0000000000000000000000000000000000000000..0e73a7fe8906fda35f04ac212ebb3d5a6d82d6df --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/transition.py @@ -0,0 +1,138 @@ +# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transition""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore import Parameter +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from .initializer import lecun_init +from .mask import MaskedLayerNorm +from ...common.utils import _memory_reduce + + +class Transition(nn.Cell): + r""" + This is 2-layer MLP where the intermediate layer expands number of channels + of the input by a factor(num_intermediate_factor). + + .. math:: + Transition(\mathbf{act}) = Linear(Linear(\mathbf{act})) + + Args: + num_intermediate_factor(float): The expand factor of intermediate output + channels compared to the input. + input_dim(int): The channels of the input. + batch_size(int): The batch size of parameters in Transition, + used in while control flow. Default: "None". + slice_num (int): The slice num used in transition layer + when the memory is overflow. Default: 0. + + Inputs: + - **act** (Tensor) - The input with channels equal to input_dim, shape is (..., input_dim). + - **index** (Tensor) - The index of while loop, only used in case of while control + flow. Default: "None". + - **mask** (Tensor) - The mask of act when to do layernorm with shape :math:`(32, input_{dim})`, + Default: "None". + + Outputs: + Tensor, the float tensor of the output of the layer with shape (..., input_dim). 
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import Transition + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> model = Transition(num_intermediate_factor=4, input_dim=128) + >>> input = Tensor(np.ones((32, 128, 128)), mstype.float32) + >>> output= model(input) + >>> print(output.shape) + (32, 128, 128) + """ + + def __init__(self, num_intermediate_factor, input_dim, batch_size=None, slice_num=0): + super(Transition, self).__init__() + self.matmul = P.MatMul(transpose_b=True) + self.input_dim = input_dim + self.num_intermediate = int(input_dim * num_intermediate_factor) + self.batch_size = batch_size + self.slice_num = slice_num + self.relu = nn.ReLU() + self.idx = Tensor(0, mstype.int32) + self.masked_layer_norm = MaskedLayerNorm() + self._init_parameter() + + def construct(self, act, index=None, mask=None): + '''Compute transition''' + if self.batch_size: + input_layer_norm_gamma = P.Gather()(self.input_layer_norm_gammas, index, 0) + input_layer_norm_beta = P.Gather()(self.input_layer_norm_betas, index, 0) + transition1_weight = P.Gather()(self.transition1_weights, index, 0) + transition1_bias = P.Gather()(self.transition1_biases, index, 0) + transition2_weight = P.Gather()(self.transition2_weights, index, 0) + transition2_bias = P.Gather()(self.transition2_biases, index, 0) + else: + input_layer_norm_gamma = self.input_layer_norm_gammas + input_layer_norm_beta = self.input_layer_norm_betas + transition1_weight = self.transition1_weights + transition1_bias = self.transition1_biases + transition2_weight = self.transition2_weights + transition2_bias = self.transition2_biases + act = self.masked_layer_norm(act, input_layer_norm_gamma, input_layer_norm_beta, mask=mask) + batched_inputs = (act,) + nonbatched_inputs = (transition1_weight, transition1_bias, transition2_weight, transition2_bias) + act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num) + return act + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.input_layer_norm_gammas = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32)) + self.input_layer_norm_betas = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32)) + self.transition1_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate, self.input_dim)), mstype.float32)) + self.transition1_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate)), mstype.float32)) + self.transition2_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim, self.num_intermediate)), mstype.float32)) + self.transition2_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32)) + else: + self.input_layer_norm_gammas = Parameter(Tensor(np.ones((self.input_dim)), mstype.float32)) + self.input_layer_norm_betas = Parameter(Tensor(np.zeros((self.input_dim)), mstype.float32)) + self.transition1_weights = Parameter(initializer(lecun_init(self.input_dim, initializer_name='relu'), + [self.num_intermediate, self.input_dim])) + self.transition1_biases = Parameter(Tensor(np.zeros((self.num_intermediate)), mstype.float32)) + self.transition2_weights = Parameter( + Tensor(np.zeros((self.input_dim, self.num_intermediate)), mstype.float32)) + self.transition2_biases = Parameter(Tensor(np.zeros((self.input_dim)), mstype.float32)) + + def _compute(self, act, transition1_weight, transition1_bias, 
transition2_weight, transition2_bias):
+        '''compute transition.'''
+
+        act_shape = P.Shape()(act)
+        if len(act_shape) != 2:
+            act = P.Reshape()(act, (-1, act_shape[-1]))
+        act = self.relu(P.BiasAdd()(self.matmul(act, transition1_weight), transition1_bias))
+        act = P.BiasAdd()(self.matmul(act, transition2_weight), transition2_bias)
+        act = P.Reshape()(act, act_shape)
+        return act
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/triangle.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/triangle.py
new file mode 100644
index 0000000000000000000000000000000000000000..01e4ea5726c65a03272a9fdeb6ca16886d94d80f
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/cell/triangle.py
@@ -0,0 +1,516 @@
+# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Triangle"""
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.ops import operations as P
+from mindspore.common.initializer import initializer
+from .basic import Attention
+from .initializer import lecun_init
+from .mask import MaskedLayerNorm
+from ...common.utils import _memory_reduce
+
+
+class TriangleAttention(nn.Cell):
+    r"""
+    Triangle attention. For the detailed implementation process, refer to the
+    `TriangleAttention` documentation.
+
+    The information between an amino acid pair is integrated through the information of the
+    three edges ij, ik and jk, in three stages: projection, self-attention and output. The
+    amino acid pair is first projected to obtain q, k and v; the classic multi-head
+    self-attention mechanism then incorporates the relationship between the i, j, k triangle
+    edges; finally the result is projected to the output.
+
+    Args:
+        orientation (str): Decides the orientation of the triangle attention, 'per_row' or
+            'per_column', used as the starting and ending edge of self-attention.
+        num_head (int): The number of the heads.
+        key_dim (int): The dimension of the hidden layer.
+        gating (bool): Indicator of whether the attention is gated.
+        layer_norm_dim (int): The dimension of the layer_norm.
+        batch_size (int): The batch size of triangle attention. Default: "None".
+        slice_num (int): The number of slices to be made to reduce memory. Default: 0.
+
+    Inputs:
+        - **pair_act** (Tensor) - Tensor of pair_act, shape :math:`(N_{res}, N_{res}, layer\_norm\_dim)`.
+        - **pair_mask** (Tensor) - The mask for the TriangleAttention matrix, shape :math:`(N_{res}, N_{res})`.
+        - **index** (Tensor) - The index of while loop, only used in case of while control flow. Default: "None".
+        - **mask** (Tensor) - The mask of pair_act used when applying layer normalization, with shape
+          :math:`(N_{res}, N_{res})`. Default: "None".
+
+    Outputs:
+        Tensor, the float tensor of the pair_act of the layer, with shape :math:`(N_{res}, N_{res}, layer\_norm\_dim)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import TriangleAttention
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = TriangleAttention(orientation="per_row", num_head=4, key_dim=64, gating=True, layer_norm_dim=64)
+        >>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32)
+        >>> input_1 = Tensor(np.ones((256, 256)), mstype.float32)
+        >>> out = model(input_0, input_1, index=0)
+        >>> print(out.shape)
+        (256, 256, 64)
+    """
+
+    def __init__(self, orientation, num_head, key_dim, gating, layer_norm_dim, batch_size=None, slice_num=0):
+        super(TriangleAttention, self).__init__()
+        self.num_head = num_head
+        self.orientation = orientation
+        self.orientation_is_per_column = (self.orientation == 'per_column')
+        self.init_factor = Tensor(1. / np.sqrt(layer_norm_dim), mstype.float32)
+        self.matmul = P.MatMul(transpose_b=True)
+        self.batchmatmul_b = P.BatchMatMul(transpose_b=True)
+        self.attn_mod = Attention(num_head, key_dim, gating, layer_norm_dim, layer_norm_dim, layer_norm_dim,
+                                  batch_size)
+        self.batch_size = batch_size
+        self.slice_num = slice_num
+        self.layer_norm_dim = layer_norm_dim
+        self.idx = Tensor(0, mstype.int32)
+        self.masked_layer_norm = MaskedLayerNorm()
+        self._init_parameter()
+
+    def construct(self, pair_act, pair_mask, index=None, mask=None):
+        '''construct'''
+        if self.batch_size:
+            query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
+            query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
+            feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0)
+        else:
+            query_norm_gamma = self.query_norm_gammas
+            query_norm_beta = self.query_norm_betas
+            feat_2d_weight = self.feat_2d_weights
+        if self.orientation_is_per_column:
+            pair_act = P.Transpose()(pair_act, (1, 0, 2))
+            pair_mask = P.Transpose()(pair_mask, (1, 0))
+
+        pair_mask = 1e9 * (pair_mask - 1.)
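+        # pair_mask now holds an additive bias in {-1e9, 0}; the two ExpandDims below
+        # broadcast it over the attention heads and the query dimension so that masked
+        # key positions are suppressed before the softmax inside Attention.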
+        input_mask = P.ExpandDims()(P.ExpandDims()(pair_mask, 1), 2)
+
+        pair_act = self.masked_layer_norm(pair_act, query_norm_gamma, query_norm_beta, mask)
+
+        q, k, _ = pair_act.shape
+        nonbatched_bias = self.matmul(P.Reshape()(pair_act, (-1, pair_act.shape[-1])), feat_2d_weight)
+        nonbatched_bias = P.Transpose()(P.Reshape()(nonbatched_bias, (q, k, -1)), (2, 0, 1))
+
+        batched_inputs = (pair_act, input_mask)
+        nonbatched_inputs = (index, nonbatched_bias)
+        pair_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num)
+        if self.orientation_is_per_column:
+            pair_act = P.Transpose()(pair_act, (1, 0, 2))
+        return pair_act
+
+    def _init_parameter(self):
+        '''init parameter'''
+        if self.batch_size:
+            self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+            self.feat_2d_weights = Parameter(
+                Tensor(np.zeros((self.batch_size, self.num_head, self.layer_norm_dim)), mstype.float32))
+        else:
+            self.query_norm_gammas = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32))
+            self.query_norm_betas = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32))
+            self.feat_2d_weights = Parameter(Tensor(
+                np.random.normal(scale=1 / np.sqrt(self.layer_norm_dim), size=(self.num_head, self.layer_norm_dim)),
+                mstype.float32))
+
+    def _compute(self, pair_act, input_mask, index, nonbatched_bias):
+        '''compute triangle attention'''
+        pair_act = self.attn_mod(pair_act, pair_act, input_mask, index, nonbatched_bias)
+        return pair_act
+
+
+class TriangleMultiplication(nn.Cell):
+    r"""
+    Triangle multiplication layer. For the detailed implementation process, refer to
+    `TriangleMultiplication `_.
+
+    The information between the amino acid pair is integrated through the information of three edges ij, ik, jk, and
+    the result of the dot product between ik and jk is added to the edge of ij.
+
+    Args:
+        num_intermediate_channel (int): The number of intermediate channels.
+        equation (str): The einsum-like equation used in the triangle multiplication layer. The edge
+            update forms corresponding to 'outgoing' and 'incoming' are :math:`ikc,jkc->ijc` and
+            :math:`kjc,kic->ijc` respectively.
+        layer_norm_dim (int): The last dimension length of the layer norm.
+        batch_size (int): The batch size of parameters in triangle multiplication. Default: None.
+
+    Inputs:
+        - **pair_act** (Tensor) - The pair activations, with shape :math:`(N_{res}, N_{res}, layer\_norm\_dim)`.
+        - **pair_mask** (Tensor) - The mask for the pair activations, with shape :math:`(N_{res}, N_{res})`.
+        - **index** (Tensor) - The index of the while loop, only used in case of while control flow.
+
+    Outputs:
+        Tensor, the float tensor of the pair_act of the layer, with shape :math:`(N_{res}, N_{res}, layer\_norm\_dim)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindsponge.cell import TriangleMultiplication
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import Tensor
+        >>> model = TriangleMultiplication(num_intermediate_channel=64,
+        ...                                equation="ikc,jkc->ijc", layer_norm_dim=64, batch_size=0)
+        >>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32)
+        >>> input_1 = Tensor(np.ones((256, 256)), mstype.float32)
+        >>> out = model(input_0, input_1, index=0)
+        >>> print(out.shape)
+        (256, 256, 64)
+    """
+
+    def __init__(self, num_intermediate_channel, equation, layer_norm_dim, batch_size=None):
+        super(TriangleMultiplication, self).__init__()
+        self.num_intermediate_channel = num_intermediate_channel
+        self.equation = equation
+        self.layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+        self.matmul = P.MatMul(transpose_b=True)
+        self.sigmoid = nn.Sigmoid()
+        self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
+        supported_equations = ["ikc,jkc->ijc", "kjc,kic->ijc"]
+        if self.equation not in supported_equations:
+            print("TriangleMultiplication: equation '{}' is not supported".format(self.equation))
+        # Fold the equation string into a flag: True for the 'outgoing' update,
+        # False for the 'incoming' update, None if unsupported.
+        if self.equation == "ikc,jkc->ijc":
+            self.equation = True
+        elif self.equation == "kjc,kic->ijc":
+            self.equation = False
+        else:
+            self.equation = None
+        self.batch_size = batch_size
+        self.layer_norm_dim = layer_norm_dim
+        self._init_parameter()
+
+    def construct(self, act, mask, index=None):
+        r"""
+        Builds triangle multiplication module.
+
+        Args:
+            act(Tensor): Pair activations. Data type is float.
+            mask(Tensor): Pair mask. Data type is float.
+            index(int): The index of the batch, used when batch_size is not None.
+
+        Returns:
+            act(Tensor), with shape act_shape[:-1] + (-1,).
+        """
+
+        if self.batch_size:
+            layer_norm_input_gamma = P.Gather()(self.layer_norm_input_gammas, index, 0)
+            layer_norm_input_beta = P.Gather()(self.layer_norm_input_betas, index, 0)
+            left_projection_weight = P.Gather()(self.left_projection_weights, index, 0)
+            left_projection_bias = P.Gather()(self.left_projection_biases, index, 0)
+            right_projection_weight = P.Gather()(self.right_projection_weights, index, 0)
+            right_projection_bias = P.Gather()(self.right_projection_biases, index, 0)
+            left_gate_weight = P.Gather()(self.left_gate_weights, index, 0)
+            left_gate_bias = P.Gather()(self.left_gate_biases, index, 0)
+            right_gate_weight = P.Gather()(self.right_gate_weights, index, 0)
+            right_gate_bias = P.Gather()(self.right_gate_biases, index, 0)
+            center_layer_norm_gamma = P.Gather()(self.center_layer_norm_gammas, index, 0)
+            center_layer_norm_beta = P.Gather()(self.center_layer_norm_betas, index, 0)
+            output_projection_weight = P.Gather()(self.output_projection_weights, index, 0)
+            output_projection_bias = P.Gather()(self.output_projection_biases, index, 0)
+            gating_linear_weight = P.Gather()(self.gating_linear_weights, index, 0)
+            gating_linear_bias = P.Gather()(self.gating_linear_biases, index, 0)
+        else:
+            layer_norm_input_gamma = self.layer_norm_input_gammas
+            layer_norm_input_beta = self.layer_norm_input_betas
+            left_projection_weight = self.left_projection_weights
+            left_projection_bias = self.left_projection_biases
+            right_projection_weight = self.right_projection_weights
+            right_projection_bias = self.right_projection_biases
+            left_gate_weight = self.left_gate_weights
+            left_gate_bias = self.left_gate_biases
+            right_gate_weight = self.right_gate_weights
+            right_gate_bias = self.right_gate_biases
+            center_layer_norm_gamma = self.center_layer_norm_gammas
+            center_layer_norm_beta = self.center_layer_norm_betas
+            output_projection_weight = self.output_projection_weights
+            output_projection_bias = self.output_projection_biases
+            gating_linear_weight = self.gating_linear_weights
+            gating_linear_bias = self.gating_linear_biases
+
+        mask = P.ExpandDims()(mask, -1)
+        act, _,
_ = self.layer_norm(act, + layer_norm_input_gamma, + layer_norm_input_beta) + + act_shape = P.Shape()(act) + if len(act_shape) != 2: + act = P.Reshape()(act, (-1, act_shape[-1])) + out_shape = act_shape[:-1] + (-1,) + input_act = act + left_projection = P.BiasAdd()(self.matmul(act, left_projection_weight), left_projection_bias) + + left_gate_values = P.BiasAdd()(self.matmul(act, left_gate_weight), left_gate_bias) + left_gate_values = self.sigmoid(left_gate_values) + + left_proj_act = left_projection * left_gate_values + left_proj_act = P.Reshape()(left_proj_act, out_shape) + + right_projection = P.BiasAdd()(self.matmul(act, right_projection_weight), right_projection_bias) + + right_gate_values = P.BiasAdd()(self.matmul(act, right_gate_weight), right_gate_bias) + right_gate_values = self.sigmoid(right_gate_values) + + right_proj_act = mask * P.Reshape()(right_projection * right_gate_values, out_shape) + + if self.equation is not None: + if self.equation: + left_proj_act_tmp = P.Transpose()(left_proj_act, (2, 0, 1)) + right_proj_act_tmp = P.Transpose()(right_proj_act, (2, 0, 1)) + act = self.batch_matmul_trans_b(left_proj_act_tmp, right_proj_act_tmp) + act = P.Transpose()(act, (1, 2, 0)) + else: + left_proj_act_tmp = P.Transpose()(left_proj_act, (2, 1, 0)) + right_proj_act_tmp = P.Transpose()(right_proj_act, (2, 1, 0)) + act = self.batch_matmul_trans_b(left_proj_act_tmp, right_proj_act_tmp) + act = P.Transpose()(act, (2, 1, 0)) + + act, _, _ = self.layer_norm(act, + center_layer_norm_gamma, + center_layer_norm_beta) + + if len(act_shape) != 2: + act = P.Reshape()(act, (-1, act_shape[-1])) + + act = P.BiasAdd()(self.matmul(act, output_projection_weight), output_projection_bias) + gate_values = P.BiasAdd()(self.matmul(input_act, gating_linear_weight), gating_linear_bias) + gate_values = self.sigmoid(gate_values) + + act = P.Reshape()(act * gate_values, out_shape) + return act + + def _init_parameter(self): + '''init parameter''' + if self.batch_size: + self.layer_norm_input_gammas = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + self.layer_norm_input_betas = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + self.left_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel, self.layer_norm_dim)), + mstype.float32)) + self.left_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel)), mstype.float32)) + self.right_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel, self.layer_norm_dim)), + mstype.float32)) + self.right_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel)), mstype.float32)) + self.left_gate_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel, self.layer_norm_dim)), + mstype.float32)) + self.left_gate_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel)), mstype.float32)) + self.right_gate_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel, self.layer_norm_dim)), + mstype.float32)) + self.right_gate_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_intermediate_channel)), mstype.float32)) + self.center_layer_norm_gammas = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + self.center_layer_norm_betas = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), 
mstype.float32)) + self.output_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) + self.output_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + self.gating_linear_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) + self.gating_linear_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) + else: + self.layer_norm_input_gammas = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32)) + self.layer_norm_input_betas = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32)) + self.left_projection_weights = Parameter(initializer(lecun_init(self.num_intermediate_channel), + [self.num_intermediate_channel, + self.layer_norm_dim])) + self.left_projection_biases = Parameter( + Tensor(np.zeros((self.num_intermediate_channel)), mstype.float32)) + self.right_projection_weights = Parameter(initializer(lecun_init(self.num_intermediate_channel), + [self.num_intermediate_channel, + self.layer_norm_dim])) + self.right_projection_biases = Parameter( + Tensor(np.zeros((self.num_intermediate_channel)), mstype.float32)) + self.left_gate_weights = Parameter( + Tensor(np.zeros((self.num_intermediate_channel, self.layer_norm_dim)), mstype.float32)) + self.left_gate_biases = Parameter(Tensor(np.ones((self.num_intermediate_channel)), mstype.float32)) + self.right_gate_weights = Parameter( + Tensor(np.zeros((self.num_intermediate_channel, self.layer_norm_dim)), mstype.float32)) + self.right_gate_biases = Parameter(Tensor(np.ones((self.num_intermediate_channel)), mstype.float32)) + self.center_layer_norm_gammas = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32)) + self.center_layer_norm_betas = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32)) + self.output_projection_weights = Parameter( + Tensor(np.zeros((self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) + self.output_projection_biases = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32)) + self.gating_linear_weights = Parameter( + Tensor(np.zeros((self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) + self.gating_linear_biases = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32)) + + +class OuterProductMean(nn.Cell): + r""" + Computing the correlation of the input tensor along its second dimension, the computed correlation + could be used to update the correlation features(e.g. the Pair representation). + + .. math:: + OuterProductMean(\mathbf{act}) = Linear(flatten(mean(\mathbf{act}\otimes\mathbf{act}))) + + Args: + num_outer_channel (float): The last dimension size of intermediate layer in OuterProductMean. + act_dim (int): The last dimension size of the input act. + num_output_channel (int): The last dimension size of output. + batch_size(int): The batch size of parameters in OuterProductMean, + used in while control flow. Default: "None". + slice_num (int): The slice num used in OuterProductMean layer + when the memory is overflow. Default: 0. + + Inputs: + - **act** (Tensor) - The input tensor with shape :math:`(dim_1, dim_2, act\_dim)`. + - **mask** (Tensor) - The mask for OuterProductMean with shape :math:`(dim_1, dim_2)`. + - **mask_norm** (Tensor) - Squared L2-norm along the first dimension of **mask**, + pre-computed to avoid re-computing, its shape is :math:`(dim_2, dim_2, 1)`. 
+ - **index** (Tensor) - The index of while loop, only used in case of while control + flow. Default: "None". + + Outputs: + Tensor, the float tensor of the output of OuterProductMean layer with + shape :math:`(dim_2, dim_2, num\_output\_channel)`. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import OuterProductMean + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> from mindspore.ops import operations as P + >>> model = OuterProductMean(num_outer_channel=32, act_dim=128, num_output_channel=256) + >>> act = Tensor(np.ones((32, 64, 128)), mstype.float32) + >>> mask = Tensor(np.ones((32, 64)), mstype.float32) + >>> mask_norm = P.ExpandDims()(P.MatMul(transpose_a=True)(mask, mask), -1) + >>> output= model(act, mask, mask_norm) + >>> print(output.shape) + (64, 64, 256) + """ + + def __init__(self, num_outer_channel, act_dim, num_output_channel, batch_size=None, slice_num=0): + super(OuterProductMean, self).__init__() + self.num_output_channel = num_output_channel + self.num_outer_channel = num_outer_channel + self.layer_norm_input = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) + self.matmul_trans_b = P.MatMul(transpose_b=True) + self.matmul = P.MatMul() + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.act_dim = act_dim + self.batch_size = batch_size + self.slice_num = slice_num + self.idx = Tensor(0, mstype.int32) + self._init_parameter() + + def construct(self, act, mask, mask_norm, index=None): + """Compute outer product mean.""" + + if self.batch_size: + layer_norm_input_gamma = P.Gather()(self.layer_norm_input_gammas, index, 0) + layer_norm_input_beta = P.Gather()(self.layer_norm_input_betas, index, 0) + left_projection_weight = P.Gather()(self.left_projection_weights, index, 0) + left_projection_bias = P.Gather()(self.left_projection_biases, index, 0) + right_projection_weight = P.Gather()(self.right_projection_weights, index, 0) + right_projection_bias = P.Gather()(self.right_projection_biases, index, 0) + linear_output_weight = P.Gather()(self.linear_output_weights, index, 0) + linear_output_bias = P.Gather()(self.o_biases, index, 0) + else: + layer_norm_input_gamma = self.layer_norm_input_gammas + layer_norm_input_beta = self.layer_norm_input_betas + left_projection_weight = self.left_projection_weights + left_projection_bias = self.left_projection_biases + right_projection_weight = self.right_projection_weights + right_projection_bias = self.right_projection_biases + linear_output_weight = self.linear_output_weights + linear_output_bias = self.o_biases + mask = P.ExpandDims()(mask, -1) + act, _, _ = self.layer_norm_input(act, layer_norm_input_gamma, layer_norm_input_beta) + act_shape = P.Shape()(act) + if len(act_shape) != 2: + act = P.Reshape()(act, (-1, act_shape[-1])) + out_shape = act_shape[:-1] + (-1,) + left_act = mask * P.Reshape()( + P.BiasAdd()(self.matmul_trans_b(act, left_projection_weight), left_projection_bias), out_shape) + right_act = mask * P.Reshape()( + P.BiasAdd()(self.matmul_trans_b(act, right_projection_weight), right_projection_bias), out_shape) + a, d, e = right_act.shape + right_act = P.Reshape()(right_act, (a, -1)) + batched_inputs = (left_act,) + nonbatched_inputs = (right_act, linear_output_weight, linear_output_bias, d, e) + act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num, 1) + epsilon = 1e-3 + act = P.RealDiv()(act, epsilon + mask_norm) + return act + + def 
_init_parameter(self): + '''init parameter''' + if self.batch_size: + self.layer_norm_input_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32)) + self.layer_norm_input_betas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32)) + self.left_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_outer_channel, self.act_dim)), mstype.float32)) + self.left_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_outer_channel)), mstype.float32)) + self.right_projection_weights = Parameter( + Tensor(np.zeros((self.batch_size, self.num_outer_channel, self.act_dim)), mstype.float32)) + self.right_projection_biases = Parameter( + Tensor(np.zeros((self.batch_size, self.num_outer_channel)), mstype.float32)) + self.linear_output_weights = Parameter(Tensor(np.zeros( + (self.batch_size, self.num_output_channel, self.num_outer_channel * + self.num_outer_channel)), mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_output_channel)), mstype.float32)) + else: + self.layer_norm_input_gammas = Parameter(Tensor(np.ones((self.act_dim)), mstype.float32)) + self.layer_norm_input_betas = Parameter(Tensor(np.zeros((self.act_dim)), mstype.float32)) + self.left_projection_weights = Parameter( + initializer(lecun_init(self.act_dim), [self.num_outer_channel, self.act_dim])) + self.left_projection_biases = Parameter(Tensor(np.zeros((self.num_outer_channel)), mstype.float32)) + self.right_projection_weights = Parameter( + initializer(lecun_init(self.act_dim), [self.num_outer_channel, self.act_dim])) + self.right_projection_biases = Parameter(Tensor(np.zeros((self.num_outer_channel)), mstype.float32)) + self.linear_output_weights = Parameter( + Tensor(np.zeros((self.num_output_channel, self.num_outer_channel * self.num_outer_channel)), + mstype.float32)) + self.o_biases = Parameter(Tensor(np.zeros((self.num_output_channel)), mstype.float32)) + + def _compute(self, left_act, right_act, linear_output_weight, linear_output_bias, d, e): + '''compute outer product mean''' + + a, b, c = left_act.shape + left_act = P.Reshape()(P.Transpose()(left_act, (2, 1, 0)), (-1, a)) + act = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act, right_act), + (c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e)) + act_shape = P.Shape()(act) + if len(act_shape) != 2: + act = P.Reshape()(act, (-1, act_shape[-1])) + act = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act, linear_output_weight), + linear_output_bias), (d, b, -1)) + act = P.Transpose()(act, (1, 0, 2)) + return act diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e19e1f6a454239e60db2ab1ed272c8e1da21bc --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PSP""" + +from .psp import PSP +from .pdbbind import PDBBind +from .dataset import curry1, data_process_run, DataSet diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/dataset.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6259fe7eab978107d29a828e3d948f598c8989e1 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/dataset.py @@ -0,0 +1,64 @@ +# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""dataset""" +from abc import ABCMeta, abstractmethod + + +def curry1(f): + """Supply all arguments but the first.""" + + def fc(*args, **kwargs): + return lambda x: f(x, *args, **kwargs) + + return fc + + +def data_process_run(data, funcs): + for f in funcs: + data = f(data) + return data + + +class DataSet(metaclass=ABCMeta): + """DataSet""" + def __init__(self): + self.phase = None + + @abstractmethod + def __getitem__(self): + pass + + @abstractmethod + def __len__(self): + pass + + def set_phase(self, phase): + self.phase = phase + + @abstractmethod + def process(self, data, label=None): + pass + + @abstractmethod + def download(self, path=None): + pass + + @abstractmethod + def data_parse(self, input_data, idx): + pass + + @abstractmethod + def create_iterator(self, num_epochs): + pass diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/pdbbind/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/pdbbind/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d07caa3db0b7e8fc6bd4e511fede82eebbf6c953 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/pdbbind/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PDBBind""" + +from .pdbbind import PDBBind diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/pdbbind/pdbbind.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/pdbbind/pdbbind.py new file mode 100644 index 0000000000000000000000000000000000000000..3fe1dcec9e086a748052b48b1422062197580a39 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/pdbbind/pdbbind.py @@ -0,0 +1,84 @@ +# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PDBBind""" +import os +import tarfile +from tqdm import tqdm +from ..dataset import DataSet + + +class PDBBind(DataSet): + """"PDBBind Dataset""" + def __init__(self): + + self.url = { + "index": "http://www.pdbbind.org.cn/download/PDBbind_2016_plain_text_index.tar.gz", + "general": "http://www.pdbbind.org.cn/download/pdbbind_v2016_general-set-except-refined.tar.gz", + "refined": "http://www.pdbbind.org.cn/download/pdbbind_v2016_refined.tar.gz", + "pp": "http://www.pdbbind.org.cn/download/pdbbind_v2016_PP.tar.gz", + "mol2": "http://www.pdbbind.org.cn/download/PDBbind_v2016_mol2.tar.gz", + "sdf": "http://www.pdbbind.org.cn/download/PDBbind_v2016_sdf.tar.gz", + "2013": "http://www.pdbbind.org.cn/download/pdbbind_v2013_core_set.tar.gz" + } + + self.cache = "./PDBBind_data" + self.in_memory = True + super().__init__() + + def __getitem__(self, idx): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def download(self, path=None): + """download""" + if path is not None: + self.cache = path + print("Start download data") + for _, url in self.url.items(): + command = "wget -P " + self.cache + " " + url + os.system(command) + + file_list = os.listdir(path) + tar_gz_list = [] + for val in file_list: + if val.endswith("tar.gz"): + tar_gz_list.append(val) + + print("Start uncompression ... ") + for i in tqdm(range(len(tar_gz_list))): + val = tar_gz_list[i] + val_path = os.path.join(path, val) + if "PDBbind_2016_plain_text_index" in val: + dir_path = os.path.join(self.cache, "PDBbind_2016_plain_text_index/") + if not os.path.exists(dir_path): + os.makedirs(dir_path) + tar_file = tarfile.open(val) + tar_file.extractall(dir_path) + else: + tar_file = tarfile.open(val) + tar_file.extractall(val_path) + print("Finish uncompression ... 
") + print("PDBBind has been saved in ", self.cache) + + def process(self, data, label=None): + raise NotImplementedError + + def data_parse(self, input_data, idx): + raise NotImplementedError + + def create_iterator(self, num_epochs): + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/psp/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/psp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31ff47d8756f2a6bb6ba4773018c6391bd06fae8 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/psp/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PSP""" + +from .psp import PSP diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/psp/psp.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/psp/psp.py new file mode 100644 index 0000000000000000000000000000000000000000..e9efc5f2fbad767eadf25b6de1ca2108dd12ae22 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/dataset/psp/psp.py @@ -0,0 +1,95 @@ +# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""PSP""" +import os +import tarfile +from tqdm import tqdm +from ..dataset import DataSet + + +def dir_walk(path, file_list): + files = os.listdir(path) + for file in files: + file_path = os.path.join(path, file) + if os.path.isdir(file_path): + dir_walk(file_path, file_list) + else: + file_list.append(file_path) + + +class PSP(DataSet): + """PSP DataSet""" + def __init__(self): + + self.url = { + "train": ["http://ftp.cbi.pku.edu.cn/psp/true_structure_dataset/", + "http://ftp.cbi.pku.edu.cn/psp/distillation_dataset/"], + "validation": ["http://ftp.cbi.pku.edu.cn/psp/new_validation_dataset/"], + "examples": ["https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/examples/"]} + + self.cache = "./psp_data/" + self.in_memory = False + super().__init__() + self.mode = ["train", "validation", "examples"] + + def __getitem__(self): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def download(self, path=None, mode="validation"): + """download""" + if path is not None: + self.cache = path + # WARNING: just for linux OS + print("Start download data for mode : ", mode) + for url in self.url[mode]: + command = "wget -c -r -np -k -L -p -P " + self.cache + " " + url + os.system(command) + + file_list = [] + dir_walk(self.cache, file_list) + tar_gz_list = [] + for val in file_list: + if val.endswith("tar.gz"): + tar_gz_list.append(val) + + print("Start uncompression ... ") + for i in tqdm(range(len(tar_gz_list))): + val = tar_gz_list[i] + short_path, _ = os.path.split(val.split("/psp/")[-1]) + dir_path = os.path.join(self.cache, short_path) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + tar_file = tarfile.open(val) + tar_file.extractall(dir_path) + print("Finish uncompression ... ") + print("PSP DataSet has been saved in ", self.cache) + if mode == "train": + print("Make training name list") + self.make_name_list() + + def make_name_list(self): + pass + + def process(self): + raise NotImplementedError + + def data_parse(self, input, idx): + raise NotImplementedError + + def create_iterator(self): + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95b1405090b237ddac9fb3a3c80256a25f2a630d --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Models""" +from .multimer import Multimer, MultimerDataSet, multimer_configuration +from .colabdesign import COLABDESIGN, ColabDesignDataSet, colabdesign_configuration diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eeda7e6ab3edc6a6c78764f849baa1089035d1b2 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""colabdesign""" +from .colabdesign_dataset import ColabDesignDataSet +from .colabdesign_configuratuin import colabdesign_configuration +from .colabdesign import COLABDESIGN diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign.py new file mode 100644 index 0000000000000000000000000000000000000000..cfcc7d337884fd70ac692074ee8d33c8ba420571 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign.py @@ -0,0 +1,105 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""colabdesign""" +import numpy as np + +from mindspore import Parameter +from mindspore import Tensor, load_checkpoint +import mindspore as ms +from mindspore import jit, context + +from .nn_arch import Colabdesign +from ..model import Model +from .module.design_wrapcell import TrainOneStepCell, WithLossCell +from .module.utils import get_weights, get_lr, get_opt + + +class COLABDESIGN(Model): + """ColabDesign""" + name = "COLABDESIGN" + feature_list = ["msa_feat", "msa_mask", "seq_mask_batch", \ + "template_aatype", "template_all_atom_masks", "template_all_atom_positions", "template_mask", \ + "template_pseudo_beta_mask", "template_pseudo_beta", \ + "extra_msa", "extra_has_deletion", "extra_deletion_value", "extra_msa_mask", \ + "residx_atom37_to_atom14", "atom37_atom_exists_batch", \ + "residue_index_batch", "batch_aatype", "batch_all_atom_positions", "batch_all_atom_mask", + "opt_temp", \ + "opt_soft", "opt_hard", "prev_pos", "prev_msa_first_row", "prev_pair"] + + def __init__(self, config): + context.set_context(memory_optimize_level="O1", max_call_depth=6000) + if context.get_context("device_target") == "GPU": + self.mixed_precision = False + context.set_context(graph_kernel_flags="--disable_expand_ops=Softmax --disable_cluster_ops=ReduceSum " + "--composite_op_limit_size=50", enable_graph_kernel=True) + else: + self.mixed_precision = True + + self.config = config + self.use_jit = self.config.use_jit + self.checkpoint_url = \ + 'https://download.mindspore.cn/mindscience/mindsponge/Multimer/checkpoint/Multimer_Model_1.ckpt' + self.checkpoint_path = "./colabdesign.ckpt" + seq_vector = 0.01 * np.random.normal(0, 1, size=(1, 100, 20)) + self.network = Colabdesign(self.config, self.mixed_precision, Tensor(seq_vector, ms.float16), 100, + protocol=self.config.protocol) + load_checkpoint(self.checkpoint_path, self.network) + net_with_criterion = WithLossCell(self.network) + soft_weights, temp_weights = get_weights(self.config, self.config.soft_iters, self.config.temp_iters, + self.config.hard_iters) + epoch = self.config.soft_iters + self.config.temp_iters + self.config.hard_iters + lr = get_lr(temp_weights, soft_weights, epoch) + model_params = [Parameter(Tensor(seq_vector, ms.float16))] + opt = get_opt(model_params, lr, 0.0, self.config.opt_choice) + self.train_net = TrainOneStepCell(net_with_criterion, opt, sens=8192) + super().__init__(self.checkpoint_url, self.network, self.name) + + # pylint: disable=arguments-differ + def predict(self, data): + pass + + def forward(self, data): + pass + + # pylint: disable=arguments-differ + @jit + def backward(self, feat): + loss = self.train_net(*feat) + return loss + + # pylint: disable=arguments-differ + def train_step(self, data): + features = [] + for feature in data: + features.append(Tensor(data[feature])) + + loss = self.backward(features) + + return loss + + def _pynative_forward(self, data): + pass + + @jit + def _jit_forward(self, data): + pass diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign_configuratuin.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign_configuratuin.py new file mode 100644 index 0000000000000000000000000000000000000000..db679a98e7f8ad8bf937746a2a6f78e484745bab --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign_configuratuin.py @@ -0,0 +1,26 @@ +# Copyright 2023 @ Shenzhen Bay 
Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""colabdesign_configuration""" +colabdesign_configuration = { + "fold_design": "https://download.mindspore.cn/mindscience/mindsponge/Multimer/config/" +} diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign_data.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign_data.py new file mode 100644 index 0000000000000000000000000000000000000000..57c0c41d501305f60d262ad5a895eea6ba9cbb42 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign_data.py @@ -0,0 +1,135 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""colabdesign data"""
+import numpy as np
+
+from ...dataset import curry1
+# residue_constants is taken from the vendored mindsponge1 package.
+from ....common import residue_constants
+
+
+@curry1
+def dict_filter_key(feature, feature_list):
+    feature = {k: v for k, v in feature.items() if k in feature_list}
+    return feature
+
+
+@curry1
+def dict_replace_key(feature, replaced_key):
+    assert len(replaced_key) == 2
+    origin_key, new_key = replaced_key
+    if origin_key in feature:
+        feature[new_key] = feature.pop(origin_key)
+    return feature
+
+
+@curry1
+def dict_cast(feature, cast_type, filtered_list):
+    assert len(cast_type) == 2
+    origin_type = cast_type[0]
+    new_type = cast_type[1]
+    for k, v in feature.items():
+        if k not in filtered_list:
+            if v.dtype == origin_type:
+                feature[k] = v.astype(new_type)
+    return feature
+
+
+@curry1
+def dict_suqeeze(feature=None, filter_list=None, axis=None):
+    for k in filter_list:
+        if k in feature:
+            feat_dim = feature[k].shape[axis]
+            if isinstance(feat_dim, int) and feat_dim == 1:
+                feature[k] = np.squeeze(feature[k], axis=axis)
+    return feature
+
+
+@curry1
+def dict_take(feature, filter_list, axis):
+    for k in filter_list:
+        if k in feature:
+            feature[k] = feature[k][axis]
+    return feature
+
+
+@curry1
+def dict_del_key(feature, filter_list):
+    for k in filter_list:
+        if k in feature:
+            del feature[k]
+    return feature
+
+
+@curry1
+def one_hot_convert(feature, key, axis):
+    if key in feature:
+        feature[key] = np.argmax(feature[key], axis=axis)
+    return feature
+
+
+@curry1
+def correct_restypes(feature, key):
+    new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+    new_order = np.array(new_order_list, dtype=feature[key].dtype)
+    feature[key] = new_order[feature[key]]
+    return feature
+
+
+@curry1
+def prep(feature=None, cfg=None):
+    prev_pos = np.zeros((cfg.seq_length, 37, 3)).astype(np.float32)
+    prev_msa_first_row = np.zeros((cfg.seq_length, cfg.model.msa_channel)).astype(np.float32)
+    prev_pair = np.zeros((cfg.seq_length, cfg.seq_length, cfg.model.pair_channel)).astype(np.float32)
+    feature.append(prev_pos)
+    feature.append(prev_msa_first_row)
+    feature.append(prev_pair)
+    return feature
+
+
+@curry1
+def get_weights(feature=None, index=None, cfg=None):
+    """get weights"""
+    opt_temp = []
+    opt_soft = []
+    opt_hard = []
+
+    for i in range(cfg.soft_iters):
+        opt_temp.append(
+            cfg.soft_etemp + (cfg.soft_temp - cfg.soft_etemp) * (1 - (i + 1) / cfg.soft_iters) ** 2)
+        opt_soft.append((i + 1) / cfg.soft_iters)
+        opt_hard.append(cfg.soft_hard)
+    for i in range(cfg.temp_iters):
+        opt_temp.append(
+            cfg.temp_decay + (cfg.temp_value - cfg.temp_decay) * (1 - (i + 1) / cfg.temp_iters) ** 2)
+        opt_soft.append(cfg.temp_esoft + (cfg.temp_soft - cfg.temp_esoft) * ((i + 1) / cfg.temp_iters))
+        opt_hard.append(cfg.temp_ehard + (cfg.temp_hard - cfg.temp_ehard) * ((i + 1) / cfg.temp_iters))
+    for i in range(cfg.hard_iters):
+        opt_temp.append(
+            cfg.hard_etemp + (cfg.hard_temp - cfg.hard_etemp) * (1 - (i + 1) / cfg.hard_iters) ** 2)
+        opt_soft.append(cfg.hard_esoft + (cfg.hard_soft - cfg.hard_esoft) * ((i + 1) / cfg.hard_iters))
+        opt_hard.append(cfg.hard_decay + (cfg.hard_value - cfg.hard_decay) * ((i + 1) / cfg.hard_iters))
+    feature.append(opt_temp[index])
+    feature.append(opt_soft[index])
+    feature.append(opt_hard[index])
+    return feature
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign_dataset.py
b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2467b682138f79d1c1919b99cb846707f12665d5 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/colabdesign_dataset.py @@ -0,0 +1,101 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""colabdesign dataset""" +import os +import pickle + +from mindspore.dataset import GeneratorDataset + +from ...dataset import PSP, data_process_run +from .colabdesign_data import prep, get_weights + + +class ColabDesignDataSet(PSP): + """ColabDesignDataSet""" + + def __init__(self, config, num_seq=1): + self.config = config + self.supported_models = ['ColabDesign'] + self.in_memory = False + self.colabdesign_inputs() + self.indx = 0 + self.training_data_src = "" + self.training_pkl_path = "" + self.training_pdb_path = "" + self.training_pdb_items = "" + self.training_pkl_items = "" + self.data_process = [get_weights(self.indx, cfg=config), prep(cfg=config)] + + self._num = num_seq + super().__init__() + + def __getitem__(self, idx): + if self.in_memory: + data = self.inputs[idx] + else: + data = self.data_parse(idx) + + self.indx += 1 + features = self.process(data) + return tuple(features) + + def __len__(self): + data_len = len(os.listdir(self.training_pdb_path)) + return data_len + + def colabdesign_inputs(self): + feature_list = ["msa_feat", "msa_mask", "seq_mask_batch", \ + "template_aatype", "template_all_atom_masks", "template_all_atom_positions", "template_mask", \ + "template_pseudo_beta_mask", "template_pseudo_beta", \ + "extra_msa", "extra_has_deletion", "extra_deletion_value", "extra_msa_mask", \ + "residx_atom37_to_atom14", "atom37_atom_exists_batch", \ + "residue_index_batch", "batch_aatype", "batch_all_atom_positions", "batch_all_atom_mask", + "opt_temp", \ + "opt_soft", "opt_hard", "prev_pos", "prev_msa_first_row", "prev_pair"] + self.feature_list = feature_list + + # pylint: disable=arguments-differ + def data_parse(self, idx): + pkl_path = self.training_pkl_items[idx] + f = open(pkl_path, "rb") + data = pickle.load(f) + return data + + # pylint: disable=arguments-differ + def process(self, data): + features = data_process_run(data.copy(), self.data_process) + return features + + def set_training_data_src(self, data_src): + self.training_data_src = data_src + self.training_pkl_path = self.training_data_src + "/pkl/" + self.training_pdb_path = self.training_data_src + "/pdb/" + self.training_pdb_items = [self.training_pdb_path + key for key in 
sorted(os.listdir(self.training_pdb_path))] + self.training_pkl_items = [self.training_pkl_path + key for key in sorted(os.listdir(self.training_pkl_path))] + + # pylint: disable=arguments-differ + def create_iterator(self, num_epochs): + dataset = GeneratorDataset(source=self, column_names=self.feature_list, num_parallel_workers=4, shuffle=False, + max_rowsize=16) + iteration = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True) + return iteration diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d27dd78d05d135f5c629cc6a40a9e8c96ae6cae --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""module""" diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/design_wrapcell.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/design_wrapcell.py new file mode 100644 index 0000000000000000000000000000000000000000..d0f4f1ac16c2d49dee06ba309d8926d5ab70299f --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/design_wrapcell.py @@ -0,0 +1,130 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""design wrapcell"""
+import mindspore.ops as ops
+import mindspore.common.dtype as mstype
+from mindspore import nn
+from mindspore.nn import DistributedGradReducer
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, _get_parallel_mode)
+from mindspore.context import ParallelMode
+
+GRADIENT_CLIP_TYPE = 1
+GRADIENT_CLIP_VALUE = 0.001
+
+clip_grad = ops.MultitypeFuncGraph("clip_grad")
+
+
+@clip_grad.register("Number", "Number", "Tensor")
+def _clip_grad(clip_type, clip_value, grad):
+    """clip grad"""
+    if clip_type not in (0, 1):
+        return grad
+    dt = ops.dtype(grad)
+    if clip_type == 0:
+        new_grad = ops.clip_by_value(grad, ops.cast(ops.tuple_to_array((-clip_value,)), dt),
+                                     ops.cast(ops.tuple_to_array((clip_value,)), dt))
+    else:
+        new_grad = nn.ClipByNorm()(grad, ops.cast(ops.tuple_to_array((clip_value,)), dt))
+    return new_grad
+
+
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+
+
+@grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    """grad scale"""
+    return grad * ops.Reciprocal()(scale)
+
+
+grad_mul = C.MultitypeFuncGraph("grad_mul")
+
+
+@grad_mul.register("Tuple", "Tensor")
+def tensor_grad_mul(x, y):
+    """grad mul"""
+    return x * y
+
+
+grad_square = C.MultitypeFuncGraph("grad_square")
+
+
+@grad_square.register("Tensor")
+def tensor_grad_square(x):
+    """grad square"""
+    x_temp = ops.Square()(x).astype(mstype.float32)
+    x_square = (x_temp.sum(-1, keepdims=True) > 0).astype(mstype.float32)
+    x_square = x_square.sum(-2, keepdims=True).astype(mstype.float32)
+    x_sqrt = ops.Sqrt()(x_square).astype(mstype.float32)
+    x_final = ops.div(x_sqrt, GRADIENT_CLIP_VALUE)
+    return x_final[0][0][0]
+
+
+class TrainOneStepCell(nn.Cell):
+    """TrainOneStepCell"""
+
+    def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True, use_global_norm=True):
+        super(TrainOneStepCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.set_grad()
+        self.optimizer = optimizer
+        self.weights = self.optimizer.parameters
+        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
+        self.sens = sens
+        self.enable_clip_grad = enable_clip_grad
+        self.hyper_map = ops.HyperMap()
+        self.use_global_norm = use_global_norm
+
+        self.grad_reducer = F.identity
+        self.parallel_mode = _get_parallel_mode()
+        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
+        if self.reducer_flag:
+            self.mean = _get_gradients_mean()
+            self.degree = _get_device_num()
+            # All-reduce the gradients across devices in data-parallel mode.
+            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
+
+    def construct(self, *inputs):
+        """construct"""
+        loss = self.network(*inputs)
+        sens = F.fill(loss.dtype, loss.shape, self.sens)
+        grads = self.grad(self.network, self.weights)(*inputs, sens)
+        grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_tensor(self.sens)), grads)
+        if self.enable_clip_grad:
+            if self.use_global_norm:
+                eff_len = self.hyper_map(grad_square, grads)
+                grads = C.clip_by_global_norm(grads, GRADIENT_CLIP_VALUE)
+                grads = self.hyper_map(ops.partial(grad_mul, eff_len), grads)
+            else:
+                grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        grads = self.grad_reducer(grads)
+
+        loss = F.depend(loss, self.optimizer(grads))
+        return loss
+
+
+class WithLossCell(nn.Cell):
+    """WithLossCell"""
+
+    def __init__(self, backbone):
+        super(WithLossCell,
self).__init__(auto_prefix=False) + self._backbone = backbone + + def construct(self, *inputs): + """construct""" + out = self._backbone(*inputs) + return out diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/loss_design.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/loss_design.py new file mode 100644 index 0000000000000000000000000000000000000000..f722617e23fafdaa077bf198e98a4f38bddcbf2e --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/loss_design.py @@ -0,0 +1,410 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""design loss""" +import numpy as np +import mindspore as ms +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.ops import operations as P +import mindsponge.common.residue_constants as residue_constants +from mindsponge.common.utils import pseudo_beta_fn + + +class SVD(nn.Cell): + """SVD""" + + def __init__(self): + super(SVD, self).__init__() + self.matmul_a = P.MatMul(transpose_a=True) + + def qr_split(self, ori_matrix): + """QR Split""" + shapes = ori_matrix.shape[0] + quadrature_matrix = [[]] + res = ori_matrix + for i in range(0, shapes - 1): + batch = res + if i != 0: + batch = batch[i:, i:] + x = batch[:, 0] + m = mnp.norm(x) + y = [0 for j in range(0, shapes - i)] + y[0] = m + w = x - y + w = w / mnp.norm(w) + h = mnp.eye(shapes - i) - 2 * P.MatMul()(w.reshape(shapes - i, 1), w.reshape(1, shapes - i)) + if i == 0: + quadrature_matrix = h + res = P.MatMul()(h, res) + else: + dim = mnp.concatenate((mnp.eye(i), mnp.zeros((i, shapes - i))), axis=1) + h = mnp.concatenate((mnp.zeros((shapes - i, i)), h), axis=1) + h = mnp.concatenate((dim, h), axis=0) + quadrature_matrix = P.MatMul()(h, quadrature_matrix) + res = P.MatMul()(h, res) + quadrature_matrix = quadrature_matrix.T + return [quadrature_matrix, res] + + def qr_egis(self, ori_matrix): + """QR egis""" + qr = [] + shapes = ori_matrix.shape[0] + quadrature_matrix = mnp.eye(shapes) + for i in range(0, 100): + qr = self.qr_split(ori_matrix) + quadrature_matrix = P.MatMul()(quadrature_matrix, qr[0]) + ori_matrix = P.MatMul()(qr[1], qr[0]) + + ak = P.MatMul()(qr[0], qr[1]) + e = P.Ones()((3, 1), mstype.float32) + for i in range(0, shapes): + e[i] = ak[i][i] + return e, quadrature_matrix + + def rebuild_matrix(self, u, sigma, v): + """rebuild matrix""" + a = P.MatMul()(u, sigma) + a = P.MatMul()(a, np.transpose(v)) + return a + + def sort_eigenvalue(self, eigenvalues, eigenvectors): + """sort_eigenvalue""" + _, index = P.Sort(axis=0)(-1 * eigenvalues) + eigenvalues = eigenvalues[index] + eigenvectors = eigenvectors[:, index] + return eigenvalues, eigenvectors + + def svd(self, matrixa, numofleft=None): + """Singular value decomposition 
of a matrix""" + matrixat_matrixa = self.matmul_a(matrixa, matrixa) + lambda_v, x_v = self.qr_egis(matrixat_matrixa) + lambda_v, x_v = self.sort_eigenvalue(lambda_v, x_v) + sigmas = lambda_v + + sigmas_new = mnp.where(sigmas > 0, sigmas, 0) + sigmas = P.Sqrt()(sigmas_new) + + sigmas = mnp.concatenate(sigmas) + sigmasmatrix = mnp.diag(sigmas[:, 0]) + if numofleft is None: + rankofsigmasmatrix = 3 + else: + rankofsigmasmatrix = numofleft + sigmasmatrix = sigmasmatrix[0:rankofsigmasmatrix, :] + + x_u = mnp.zeros((matrixa.shape[0], rankofsigmasmatrix)) + for i in range(rankofsigmasmatrix): + x_u[:, i] = (P.MatMul()(matrixa, x_v[:, i]) / sigmas[i])[:, 0] + + x_v = mnp.squeeze(x_v[:, 0:numofleft]) + sigmasmatrix = sigmasmatrix[0:rankofsigmasmatrix, 0:rankofsigmasmatrix] + return x_u, mnp.diag(sigmasmatrix), x_v + + +class LossNet(nn.Cell): + "loss net" + + def __init__(self, design_cfg, protocol): + super(LossNet, self).__init__() + self.mul = P.Mul() + self.expand_dims = P.ExpandDims() + self.batch_matmul = P.BatchMatMul() + self.matmul_a = P.MatMul(transpose_a=True) + self.matmul = P.MatMul() + self.svd = SVD() + if protocol == 'fixbb': + loss_weights = design_cfg.fixbb + elif protocol == 'hallucination': + loss_weights = design_cfg.hallu + self.con_weights = loss_weights.con + self.plddt_weights = loss_weights.plddt + self.rmsd_weights = loss_weights.rmsd + self.seq_weights = loss_weights.seq + self.dgram_weights = loss_weights.dgram + self.fape_weights = loss_weights.fape + self.pae_weights = loss_weights.pae + self.exp_weights = loss_weights.exp + self.rg_weights = loss_weights.rg + + def get_fape_loss(self, true_all_atom_positions, true_all_atom_mask, final_atom_positions, clamp=10.0, + return_mtx=False): + "fape loss" + + def robust_norm(x, axis=-1, keepdims=False, eps=1e-8): + return P.Sqrt()(P.Square()(x).sum(axis=axis, keepdims=keepdims) + eps) + + def get_r(n, ca, cinput): + (v1, v2) = (cinput - ca, n - ca) + e1 = v1 / robust_norm(v1, axis=-1, keepdims=True) + c1 = self.mul(e1, v2).sum(axis=1) + c = self.expand_dims(c1, 1) + e2 = v2 - c * e1 + e2 = e2 / robust_norm(e2, axis=-1, keepdims=True) + e3 = mnp.cross(e1, e2, axis=-1) + e1 = self.expand_dims(e1, 2) + e2 = self.expand_dims(e2, 2) + e3 = self.expand_dims(e3, 2) + return mnp.concatenate([e1, e2, e3], axis=-1) + + def get_ij(r, t): + t = self.expand_dims(t, 0) - self.expand_dims(t, 1) + return self.batch_matmul(t, r) + + def loss_fn(t, p, m): + fape = robust_norm(t - p) + fape = mnp.clip(fape, 0, clamp) / 10.0 + return fape, (fape * m).sum((-1, -2)) / (m.sum((-1, -2)) + 1e-8) + + true = true_all_atom_positions + pred = final_atom_positions + + n, ca, cinput = (residue_constants.atom_order[k] for k in ["N", "CA", "C"]) + + true_mask = true_all_atom_mask + weights = true_mask[:, n] * true_mask[:, ca] * true_mask[:, cinput] + + true = get_ij(get_r(true[:, n], true[:, ca], true[:, cinput]), true[:, ca]) + pred = get_ij(get_r(pred[:, n], pred[:, ca], pred[:, cinput]), pred[:, ca]) + + return self._get_pw_loss(true, pred, loss_fn, weights=weights, return_mtx=return_mtx) + + def get_rmsd_loss(self, true_all_atom_positions, true_all_atom_mask, true_final_atom_positions): + """rmsd loss""" + true = true_all_atom_positions[:, 1] + pred = true_final_atom_positions[:, 1] + weights = true_all_atom_mask[:, 1] + return self._get_rmsd_loss(true, pred, weights=weights) + + def get_dgram_loss(self, batch_aatype, batch_all_atom, batch_all_atom_mask, dist_logits, aatype=None, + return_mtx=False): + """dgram_loss""" + + if aatype is None: + aatype = 
batch_aatype
+
+        pred = dist_logits
+        x, weights = pseudo_beta_fn(aatype=aatype,
+                                    all_atom_positions=batch_all_atom,
+                                    all_atom_masks=batch_all_atom_mask)
+        # squared pairwise distances between pseudo-beta atoms
+        dm = mnp.square(x[:, None] - x[None, :]).sum(-1, keepdims=True).astype(ms.float32)
+        bin_edges = mnp.linspace(2.3125, 21.6875, pred.shape[-1] - 1)
+        hot_value = (dm > mnp.square(bin_edges)).astype(ms.float32)
+        hot_value = hot_value.sum(-1).astype(ms.int32)
+        one_hot = nn.OneHot(depth=pred.shape[-1])
+        true_label = one_hot(hot_value).astype(ms.float32)
+
+        def loss_fn(t, p, m):
+            cce = -(t * ms.ops.log_softmax(p)).sum(-1)
+            return cce, (cce * m).sum((-1, -2)) / (m.sum((-1, -2)) + 1e-8)
+
+        return self._get_pw_loss(true_label, pred, loss_fn, weights=weights, return_mtx=return_mtx)
+
+    def get_seq_ent_loss(self, inputs):
+        """seq_ent loss"""
+        softmax = ms.nn.Softmax()
+        x = inputs / mnp.array(1.)
+        ent = -(softmax(x) * ms.ops.log_softmax(x)).sum(-1)
+        mask = mnp.ones(ent.shape[-1])
+
+        ent = (ent * mask).sum() / (mask.sum() + 1e-8)
+        return ent.mean()
+
+    def mask_loss(self, x, mask=None, mask_grad=False):
+        """mask_loss"""
+        if mask is None:
+            result = x.mean()
+        else:
+            x_masked = (x * mask).sum() / (1e-8 + mask.sum())
+            if mask_grad:
+                result = ms.ops.stop_gradient(x.mean() - x_masked) + x_masked
+            else:
+                result = x_masked
+        return result
+
+    def get_exp_res_loss(self, outputs, mask_1d=None):
+        """exp_res loss"""
+
+        sigmoid = ms.nn.Sigmoid()
+        p = sigmoid(outputs)
+        p = 1 - p[..., residue_constants.atom_order["CA"]]
+        return self.mask_loss(p, mask_1d)
+
+    def get_plddt_loss(self, outputs, mask_1d=None):
+        """plddt loss"""
+        softmax = ms.nn.Softmax()
+        p = softmax(outputs)
+        op = ops.ReverseV2(axis=[-1])
+        p = (p * op(mnp.arange(p.shape[-1]))).mean(-1)
+
+        return self.mask_loss(p, mask_1d)
+
+    def get_pae_loss(self, outputs, mask_1d=None, mask_1b=None, mask_2d=None):
+        """pae loss"""
+        # aligned error logits
+        softmax = ms.nn.Softmax()
+        p = softmax(outputs)
+        p = (p * mnp.arange(p.shape[-1])).mean(-1)
+        p = (p + p.T) / 2
+        leng = p.shape[0]
+        if mask_1d is None:
+            mask_1d = mnp.ones(leng)
+        if mask_1b is None:
+            mask_1b = mnp.ones(leng)
+        if mask_2d is None:
+            mask_2d = mnp.ones((leng, leng))
+        mask_2d = mask_2d * mask_1d[:, None] * mask_1b[None, :]
+        return self.mask_loss(p, mask_2d)
+
+    def get_con_loss(self, residue_index, loss_dgram_logits, loss_dgram_bin,
+                     mask_1d=None, mask_1b=None, mask_2d=None):
+        """con loss"""
+
+        # get top k
+        def min_k(x, k=1, mask=None):
+            sort = ops.Sort()
+            y = sort(x if mask is None else mnp.where(mask, x, Tensor(65504, dtype=ms.float32)))[0].astype(ms.float32)
+            nan_mask = mnp.where(y != Tensor(65504, dtype=ms.float32), False, True)
+            k_mask = mnp.logical_and(mnp.arange(y.shape[-1]) < k, nan_mask == Tensor(False)).astype(ms.float32)
+            return mnp.where(k_mask, y, Tensor(0)).sum(-1) / (k_mask.sum(-1) + 1e-8)
+
+        def _get_con_loss(dgram, dgram_bins, cutoff=None, binary=True):
+            """dgram to contacts"""
+            if cutoff is None:
+                cutoff = dgram_bins[-1]
+            softmax = ms.nn.Softmax()
+            bins = dgram_bins < cutoff
+            px = softmax(dgram)
+            px_ = softmax(dgram - 1e7 * (1 - bins))
+            # binary/categorical cross-entropy
+            con_loss_cat_ent = -(px_ * ms.ops.log_softmax(dgram)).sum(-1)
+            con_loss_bin_ent = -mnp.log((bins * px + 1e-8).sum(-1))
+            return mnp.where(binary, con_loss_bin_ent, con_loss_cat_ent)
+
+        idx = residue_index.flatten()
+        offset = idx[:, None] - idx[None, :]
+        # define distogram
+        dgram = loss_dgram_logits
+        dgram_bins = mnp.append(Tensor(0), loss_dgram_bin)
+        p = _get_con_loss(dgram,
dgram_bins, cutoff=mnp.array(14.), binary=mnp.array(False))
+
+        m = mnp.abs(offset) >= mnp.array(9)
+
+        if mask_1d is None:
+            mask_1d = mnp.ones(m.shape[0], dtype=bool)
+        if mask_1b is None:
+            mask_1b = mnp.ones(m.shape[0], dtype=bool)
+
+        if mask_2d is None:
+            m = mnp.logical_and(m, mnp.array(mask_1b))
+        else:
+            m = mnp.logical_and(m, mnp.array(mask_2d))
+
+        p = min_k(p, mnp.array(2), m)
+
+        return min_k(p, mnp.array(mnp.inf), mask_1d)
+
+    def rg_loss(self, final_atom_positions):
+        """Radius-of-gyration penalty on the CA trace, soft-thresholded by chain length."""
+        positions = final_atom_positions
+        ca = positions[:, residue_constants.atom_order["CA"]]
+        center = ca.mean(0)
+        rg = mnp.sqrt(mnp.square(ca - center).sum(-1).mean() + 1e-8)
+        rg_th = 2.38 * ca.shape[0] ** 0.365
+        rg = ms.nn.ELU()(rg - rg_th)
+        return rg
+
+    def construct(self, true_aatype, true_all_atom_positions, true_all_atom_mask, true_final_atom_positions,
+                  ori_seq_len, dist_logits, bin_edges, experimentally_logits, predicted_lddt_logits,
+                  aligned_error_logits, residue_index, seq_logits):
+        """construct"""
+        mask_1d = mnp.ones((ori_seq_len,))
+        mask_2d = (mask_1d[:, None] == mask_1d[None, :])
+        masks = {"mask_1d": mask_1d,
+                 "mask_2d": mask_2d}
+        fape_loss = self.get_fape_loss(true_all_atom_positions[:ori_seq_len, :, :], true_all_atom_mask[:ori_seq_len, :],
+                                       true_final_atom_positions[:ori_seq_len, :, :])
+        dgram_cce = self.get_dgram_loss(true_aatype[:ori_seq_len], true_all_atom_positions[:ori_seq_len, :, :],
+                                        true_all_atom_mask[:ori_seq_len, :], dist_logits[:ori_seq_len, :ori_seq_len, :])
+        exp_res = self.get_exp_res_loss(experimentally_logits[:ori_seq_len, :], mask_1d=mask_1d)
+        plddt = self.get_plddt_loss(predicted_lddt_logits[:ori_seq_len, :], mask_1d=mask_1d)
+        pae = self.get_pae_loss(aligned_error_logits[:ori_seq_len, :ori_seq_len, :], **masks)
+        con = self.get_con_loss(residue_index[:ori_seq_len], dist_logits[:ori_seq_len, :ori_seq_len, :], bin_edges,
+                                **masks)
+        rg_loss = self.rg_loss(true_final_atom_positions)
+        seq_loss = self.get_seq_ent_loss(seq_logits[:, :ori_seq_len, :])
+        rmsd_loss = fape_loss
+        if self.rmsd_weights:
+            rmsd_loss = self.get_rmsd_loss(true_all_atom_positions[:ori_seq_len, :, :],
+                                           true_all_atom_mask[:ori_seq_len, :],
+                                           true_final_atom_positions[:ori_seq_len, :, :])
+
+        loss_all = con * self.con_weights + exp_res * self.exp_weights + self.plddt_weights * plddt + \
+                   self.seq_weights * seq_loss + self.pae_weights * pae + fape_loss * self.fape_weights + \
+                   self.dgram_weights * dgram_cce + rmsd_loss * self.rmsd_weights + rg_loss * self.rg_weights
+        return loss_all
+
+    def _get_rmsd_loss(self, true, pred, weights=None):
+        """
+        Weighted RMSD between predicted and true CA coordinates after Kabsch
+        alignment; `weights` selects and weights the positions used for both
+        the alignment and the RMSD.
+ """ + # normalize weights + length = true.shape[-2] + if weights is None: + weights = (mnp.ones(length) / length)[..., None] + else: + weights = (weights / (weights.sum(-1, keepdims=True) + 1e-8))[..., None] + + (t_fixbb, p_fixbb, w_fixbb) = (true, pred, weights) + + (t_mu, p_mu) = ((x * w_fixbb).sum(-2, keepdims=True) / w_fixbb.sum((-1, -2)) for x in (t_fixbb, p_fixbb)) + aln = self._np_kabsch((p_fixbb - p_mu) * w_fixbb, t_fixbb - t_mu) + + align_value = P.MatMul()(pred - p_mu, aln) + t_mu + msd_scalar = (weights * mnp.square(align_value - true)).sum((-1, -2)) + rmsd = P.Sqrt()(msd_scalar + 1e-8) + + return rmsd + + def _np_kabsch(self, a, b): + """get alignment matrix for two sets of coordinates""" + ab = self.matmul_a(a, b) + + u, _, vh = self.svd.svd(ab) + flip = self._det(self.matmul(u, vh)) < 0 + u_ = mnp.where(flip, -u[..., -1].T, u[..., -1].T).T + u[..., -1] = u_ + return self.matmul(u, vh) + + def _det(self, matrix): + """det""" + # matrix dim=3 + result = matrix[0, 0] * matrix[1, 1] * matrix[2, 2] + matrix[0, 1] * matrix[1, 2] * matrix[2, 0] + \ + matrix[0, 2] * matrix[1, 0] * matrix[2, 1] - matrix[0, 2] * matrix[1, 1] * matrix[2, 0] - \ + matrix[0, 1] * matrix[1, 0] * matrix[2, 2] - matrix[0, 0] * matrix[1, 2] * matrix[2, 1] + return result + + def _get_pw_loss(self, true, pred, loss_fn, weights=None, return_mtx=False): + """get pw loss""" + + expand_dims = ops.ExpandDims() + fs = {"t": true, "p": pred, "m": expand_dims(weights, 1) * expand_dims(weights, 0)} + + mtx, loss = loss_fn(**fs) + return mtx if return_mtx else loss diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/utils.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c418ab46d262d5da9af7505e7789531cb4d0eadf --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/module/utils.py @@ -0,0 +1,67 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Schedules, optimizers and sequence decoding helpers for ColabDesign."""
+import numpy as np
+import mindspore.nn as nn
+import mindsponge.common.residue_constants as residue_constants
+
+
+def get_weights(config, soft_iters, temp_iters, hard_iters):
+    """Build the temperature / soft / hard schedules over the three design stages."""
+    opt_temp = []
+    opt_soft = []
+    opt_hard = []
+
+    for i in range(soft_iters):
+        opt_temp.append(
+            config.soft_etemp + (config.soft_temp - config.soft_etemp) * (1 - (i + 1) / soft_iters) ** 2)
+        opt_soft.append((i + 1) / soft_iters)
+        opt_hard.append(config.soft_hard)
+    for i in range(temp_iters):
+        opt_temp.append(
+            config.temp_decay + (config.temp_value - config.temp_decay) * (1 - (i + 1) / temp_iters) ** 2)
+        opt_soft.append(config.temp_esoft + (config.temp_soft - config.temp_esoft) * ((i + 1) / temp_iters))
+        opt_hard.append(config.temp_ehard + (config.temp_hard - config.temp_ehard) * ((i + 1) / temp_iters))
+    for i in range(hard_iters):
+        opt_temp.append(
+            config.hard_etemp + (config.hard_temp - config.hard_etemp) * (1 - (i + 1) / hard_iters) ** 2)
+        opt_soft.append(config.hard_esoft + (config.hard_soft - config.hard_esoft) * ((i + 1) / hard_iters))
+        opt_hard.append(config.hard_decay + (config.hard_value - config.hard_decay) * ((i + 1) / hard_iters))
+    # get_lr consumes both the temperature and the soft schedule
+    return opt_temp, opt_soft, opt_hard
+
+
+def get_lr(opt_temps, opt_softs, epoch, lr=0.1):
+    """get learning rate"""
+    lr_each_step = []
+    for i in range(epoch):
+        lr_each_step.append(lr * ((1 - opt_softs[i]) + (opt_softs[i] * opt_temps[i])))
+    lr_each_step = np.array(lr_each_step).astype(np.float32)
+    return lr_each_step
+
+
+def get_opt(model_params, lr, weight_decay, choice):
+    """get opt"""
+    # pass weight_decay by keyword: the third positional argument of nn.SGD is
+    # momentum and of nn.Adam is beta1, not weight_decay
+    if choice == 'sgd':
+        opt = nn.SGD(model_params, learning_rate=lr, weight_decay=weight_decay)
+    elif choice == 'adam':
+        opt = nn.Adam(model_params, learning_rate=lr, weight_decay=weight_decay)
+    else:
+        raise ValueError(f"unsupported optimizer choice: {choice}")
+    return opt
+
+
+def get_seqs(seq_hard):
+    """Decode one-hot (or soft) sequence tensors into amino-acid strings."""
+    aa_order = residue_constants.restype_order
+    order_aa = {b: a for a, b in aa_order.items()}
+    x = seq_hard.argmax(-1)
+    return ["".join(order_aa[a] for a in s) for s in x]
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/nn_arch.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/nn_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcc986cb2242c37ee4a3cfd8bf04d2865c5725e9
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/colabdesign/nn_arch.py
@@ -0,0 +1,141 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
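+#
+# A small, hedged example of the confidence helpers defined below (the logits
+# array is synthetic, only its [num_res, num_bins] shape matters):
+#
+#     import numpy as np
+#     logits = np.random.randn(256, 50)                      # pLDDT head output
+#     conf, plddt = compute_confidence(logits, return_lddt=True)
+#     # plddt: per-residue pLDDT on a 0-100 scale; conf: its mean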
+# ============================================================================
+"""design_fold"""
+import numpy as np
+from scipy.special import softmax
+
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+import mindspore.numpy as mnp
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindspore import ops
+
+# MegaFold backbone; assumed to sit next to this model in the pipeline layout
+from ..megafold.nn_arch import MegaFold
+from .module.loss_design import LossNet
+
+
+def compute_confidence(predicted_lddt_logits, return_lddt=False):
+    """compute confidence"""
+
+    num_bins = predicted_lddt_logits.shape[-1]
+    bin_width = 1 / num_bins
+    start_n = bin_width / 2
+    plddt = compute_plddt(predicted_lddt_logits, start_n, bin_width)
+    confidence = np.mean(plddt)
+    if return_lddt:
+        return confidence, plddt
+
+    return confidence
+
+
+def compute_plddt(logits, start_n, bin_width):
+    """Computes per-residue pLDDT from logits.
+
+    Args:
+        logits: [num_res, num_bins] output from the PredictedLDDTHead.
+
+    Returns:
+        plddt: [num_res] per-residue pLDDT.
+    """
+    bin_centers = np.arange(start=start_n, stop=1.0, step=bin_width)
+    probs = softmax(logits, axis=-1)
+    predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
+    return predicted_lddt_ca * 100
+
+
+class Colabdesign(nn.Cell):
+    """Colabdesign"""
+
+    def __init__(self, config, mixed_precision, seq_vector, ori_seq_len, protocol):
+        super(Colabdesign, self).__init__()
+        self.megafold = MegaFold(config, mixed_precision)
+        self.megafold.add_flags_recursive(train_backward=True)
+        self.cfg = config
+        self.seq_vector = seq_vector
+        self.ori_seq_len = ori_seq_len
+        self.crop_size = config.seq_length
+        self.opt_alpha = Tensor(config.opt_alpha, mstype.float32)
+        self.opt_bias = Tensor(config.opt_bias, mstype.float32)
+        self.opt_use_pssm = config.opt_use_pssm
+        self.loss_net = LossNet(config, protocol)
+
+    def soft_seq(self, x, ori_seq_len, opt_temp_num, opt_soft_num, opt_hard_num):
+        """Softmax / straight-through relaxation of the designed sequence."""
+        seq_input = x[:, :ori_seq_len, :]
+        seq_logits = seq_input * self.opt_alpha + self.opt_bias
+        seq_pssm = P.Softmax()(seq_logits)
+        seq_soft = P.Softmax()(seq_logits / opt_temp_num)
+        seq_hard = P.OneHot()(seq_soft.argmax(-1), 20, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32))
+        # straight-through estimator: hard one-hot forward, soft gradient backward
+        seq_hard = ops.stop_gradient(seq_hard - seq_soft) + seq_soft
+
+        seq_pseudo = opt_soft_num * seq_soft + (1 - opt_soft_num) * seq_input
+
+        hard_mask = opt_hard_num
+        seq_pseudo = hard_mask * seq_hard + (1 - hard_mask) * seq_pseudo
+        seqs_res = (seq_logits, seq_pssm, seq_pseudo, seq_hard)
+        return seqs_res
+
+    def update_seq(self, seq, msa_feat, ori_seq_len, seq_1hot=None, seq_pssm=None):
+        """update the sequence features"""
+
+        if seq_1hot is None:
+            seq_1hot = seq
+        if seq_pssm is None:
+            seq_pssm = seq
+
+        seq_1hot = mnp.pad(seq_1hot, [[0, 0], [0, self.crop_size - ori_seq_len], [0, 22 - seq_1hot.shape[-1]]])
+        seq_pssm = mnp.pad(seq_pssm, [[0, 0], [0, self.crop_size - ori_seq_len], [0, 22 - seq_pssm.shape[-1]]])
+        msa_feat = mnp.zeros_like(msa_feat, dtype=mstype.float32)
+        msa_feat[:seq_1hot.shape[0], :, 0:22] = seq_1hot
+        msa_feat[:seq_1hot.shape[0], :, 25:47] = seq_pssm
+        return msa_feat
+
+    def construct(self, msa_feat, msa_mask, seq_mask,
+                  template_aatype, template_all_atom_masks, template_all_atom_positions,
+                  template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion,
+                  extra_deletion_value, extra_msa_mask,
+                  residx_atom37_to_atom14, atom37_atom_exists, residue_index, true_aatype, true_all_atom_positions,
+                  true_all_atom_mask, opt_temp_num, opt_soft_num, opt_hard_num,
+                  prev_pos,
prev_msa_first_row, prev_pair): + """construct""" + seqs_res = self.soft_seq(self.seq_vector, self.ori_seq_len, + opt_temp_num, opt_soft_num, opt_hard_num) + seq_logits, seq_pssm, seq_pseudo, _ = seqs_res[0], seqs_res[1], seqs_res[2], seqs_res[3] + if self.opt_use_pssm: + pssm = seq_pssm + else: + pssm = seq_pseudo + msa_feat = self.update_seq(seq_pseudo, msa_feat, self.ori_seq_len, seq_pssm=pssm) + target_feat = msa_feat[0, :, :21] + target_feat = mnp.pad(target_feat, [[0, 0], [1, 0]]) + aatype = seq_pseudo[0].argmax(-1) + aatype = mnp.pad(aatype, [[0, self.crop_size - self.ori_seq_len]]) + + dist_logits, bin_edges, experimentally_logits, _, aligned_error_logits, \ + _, _, _, _, predicted_lddt_logits, _, _, _, \ + _, final_atom_positions = self.megafold(target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, + template_all_atom_positions, + template_mask, template_pseudo_beta_mask, + template_pseudo_beta, extra_msa, extra_has_deletion, + extra_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, + residue_index, + prev_pos, prev_msa_first_row, prev_pair) + loss_all = \ + self.loss_net(true_aatype, true_all_atom_positions, true_all_atom_mask, final_atom_positions, + self.ori_seq_len, dist_logits, bin_edges, experimentally_logits, predicted_lddt_logits, + aligned_error_logits, residue_index, seq_logits) + return loss_all diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/esm.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/esm.py new file mode 100644 index 0000000000000000000000000000000000000000..18dae86d4dda2c86fc8aaac615deb75a643b790a --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/esm.py @@ -0,0 +1,90 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
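+#
+# Usage sketch for the model defined below (the `config` object and the
+# `features` batch are illustrative; the batch layout follows the columns
+# produced by esm_dataset.py):
+#
+#     model = ESM(config)
+#     seq = model.predict(features)      # sampled sequence at config.temperature
+#     loss = model.train_step(features)  # one supervised training step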
+# ============================================================================ +"""esm model""" +import mindspore as ms +from mindspore import jit, context, nn +from mindspore import ops +# pylint: disable=relative-beyond-top-level +from .module.esm_wrapcell import TrainOneStepCell +from .nn_arch import GVPTransformerModel as esm +from ..model import Model +from .module.util import Alphabet + + +class ESM(Model): + """ESM Model""" + name = "ESM" + + def __init__(self, config): + context.set_context(memory_optimize_level="O1", max_call_depth=6000) + if context.get_context("device_target") == "GPU": + self.mixed_precision = False + context.set_context(graph_kernel_flags="--disable_expand_ops=Softmax --disable_cluster_ops=ReduceSum " + "--composite_op_limit_size=50", enable_graph_kernel=True) + else: + self.mixed_precision = True + self.config = config + self.use_jit = self.config.use_jit + self.temperature = self.config.temperature + self.checkpoint_url = 'https://download.mindspore.cn/mindscience/mindsponge/esm/checkpoint/esm_if1.ckpt' + self.checkpoint_path = "./esm_if1.ckpt" + self.alphabet = Alphabet.from_architecture('vt_medium_with_invariant_gvp') + self.network = esm(self.config, self.alphabet) + + self.feature_list = ['coords', 'confidence', 'padding_mask', 'prev_output_tokens', 'target'] + loss = nn.CrossEntropyLoss() + net_with_loss = nn.WithLossCell(self.network, loss) + opt = nn.Adam(net_with_loss.trainable_params(), learning_rate=0.0001, eps=1e-6) + self.train_net = TrainOneStepCell(net_with_loss, opt) + self.train_net.set_train() + super().__init__(self.checkpoint_url, self.network, self.name) + + def forward(self, data): + if self.use_jit: + outputs = self._jit_forward(data) + else: + outputs = self._pynative_forward(data) + return outputs + + def predict(self, inputs): + sampled_seq = self.forward(inputs) + return sampled_seq + + def loss(self, data): + pass + + def grad_operations(self, gradient): + pass + + def backward(self, feat): + loss = self.train_net(feat) + return loss + + def train_step(self, data): + result = self.backward(data) + coord_mask = ops.IsFinite()(data['coords']).all(axis=-1).all(axis=-1) + coord_mask = coord_mask[:, 1:-1] + loss = ops.ReduceSum()(result * coord_mask) / ops.ReduceSum()(ops.Cast()(coord_mask, ms.float32)) + print("loss is:", loss) + return loss + + @jit + def _jit_forward(self, data): + sampled_seq = self.network.sample(data, self.temperature) + return sampled_seq + + def _pynative_forward(self, data): + sampled_seq = self.network.sample(data, self.temperature) + return sampled_seq diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/esm_dataset.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/esm_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2e4d542c27b3d7d6bfe3f891994675395fcc1d --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/esm_dataset.py @@ -0,0 +1,120 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
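+#
+# A hedged end-to-end sketch of this dataset (the JSON path is a placeholder;
+# entries are expected to carry "seq" and "coords" keys, as consumed by
+# data_generation below):
+#
+#     ds = ESMDataSet(config)
+#     ds.set_training_data_src("train.json")
+#     for batch in ds.create_iterator(num_epochs=1):
+#         # batch keys: coords, confidence, padding_mask,
+#         #             prev_output_tokens, target
+#         ...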
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""esm dataset""" +import math +import json +import numpy as np +from mindspore.dataset import GeneratorDataset +# pylint: disable=relative-beyond-top-level +from .module.util import load_coords, CoordBatchConverter +from .module.util import Alphabet +from ...dataset import PSP + + +class ESMDataSet(PSP): + """esm dataset""" + def __init__(self, config): + self.config = config + self.alphabet = Alphabet.from_architecture(self.config.arch) + self.feature_list = ['coords', 'confidence', 'padding_mask', 'prev_output_tokens', 'target'] + self.batch_size = self.config.batch_size + self.traindata = None + self.training_data_src = "" + self.coords, self.confidence, self.padding_mask, self.prev_output_tokens, self.target = \ + None, None, None, None, None + + super().__init__() + + def __getitem__(self, item): + output = [self.coords[item], self.confidence[item], self.padding_mask[item], + self.prev_output_tokens[item], self.target[item]] + return output + + def __len__(self): + return len(self.traindata) + + def process(self, pdbfile, chain="C"): + coords, _ = load_coords(pdbfile, chain) + return coords + + def data_generation(self, alphabet): + """Data generation""" + with open(self.training_data_src, "r") as f: + traindata = json.load(f) + f.close() + self.traindata = traindata + trainset = [] + for seq in self.traindata: + trainset.append(self.mask(0.15, seq, p=0.05)) + batch = [(e["coords"], None, e["seq"]) for e in trainset[:]] + batch_converter = CoordBatchConverter(alphabet) + coords, confidence, _, tokens, padding_mask = ( + batch_converter(batch) + ) + prev_output_tokens = tokens[:, :-1] + target = tokens[:, 1:] + prev_output_tokens = prev_output_tokens.astype(np.int32) + target = target.astype(np.int32) + output = [coords, confidence, padding_mask, + prev_output_tokens, target] + return output + + def mask(self, mask_ratio, sentence, lower=1, upper=10, p=0.05): + """Span masking""" + + sent_length = len(sentence['coords']) + mask_num = math.ceil(sent_length * mask_ratio) + mask = set() + while len(mask) < mask_num: + lens = list(range(lower, upper + 1)) + len_distrib = [p * (1 - p) ** (i - lower) for i in + range(lower, upper + 1)] if p >= 0 else None + len_distrib = [x / (sum(len_distrib)) for x in len_distrib] + span_len = np.random.choice(lens, p=len_distrib) + anchor = np.random.choice(sent_length) + if anchor in mask: + continue + for i in range(anchor, anchor + span_len): + if len(mask) >= mask_num or i >= sent_length: + break + mask.add(i) + + for num in mask: + rand = np.random.random() + if rand < 0.8: + sentence['coords'][num - 1] = [[float('inf'), float('inf'), float('inf')], + [float('inf'), float('inf'), float('inf')], + [float('inf'), float('inf'), float('inf')]] + elif rand < 0.9: + # sample random token according to input distribution + sentence['coords'][num - 1] = sentence['coords'][np.random.choice(sent_length)] + return sentence + + def test_data(self, seq_length): + pass + + def download(self): + pass + + def set_training_data_src(self, data_src): + self.training_data_src = data_src + + def create_iterator(self, num_epochs): + self.coords, self.confidence, self.padding_mask, self.prev_output_tokens, self.target = \ + self.data_generation(self.alphabet) + dataset = GeneratorDataset(source=self, column_names=self.feature_list, num_parallel_workers=1, shuffle=False) + dataset = 
dataset.batch(self.batch_size) + iteration = dataset.create_dict_iterator(num_epochs=num_epochs) + return iteration diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/basic_modules.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/basic_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b923d84b8f241f98774aed120651e93140823d44 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/basic_modules.py @@ -0,0 +1,931 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""GVP operations, will be used in gvp_encoder.py""" +from typing import Dict, Optional, Tuple +import uuid +import math +import numpy as np +import mindspore as ms +import mindspore.ops as ops +from mindspore import nn, Tensor, Parameter +from mindspore.ops.primitive import Primitive +from mindspore._checkparam import Validator +from mindspore.nn.layer.activation import get_activation +from mindspore.common.initializer import Initializer, initializer,\ + XavierUniform, _calculate_fan_in_and_fan_out, _assignment +# pylint: disable=relative-beyond-top-level +from .message_passing import scatter_sum, MessagePassing +from .util import ms_transpose, _norm_no_nan, _split, tuple_cat, _merge, tuple_sum, tuple_index, utils_softmax + + +class XavierNormal(Initializer): + """Xavier normalization""" + + def __init__(self, gain=1): + super(XavierNormal, self).__init__(gain=gain) + self.gain = gain + + def _initialize(self, arr): + n_in, n_out = _calculate_fan_in_and_fan_out(arr.shape) + + std = self.gain * math.sqrt(2.0 / (n_in + n_out)) + data = np.random.normal(0, std, arr.shape) + + _assignment(arr, data) + + +class GVP(nn.Cell): + """GVP""" + + def __init__(self, in_dims, out_dims, h_dim=None, vector_gate=False, + activations=(ops.ReLU(), ops.Sigmoid()), tuple_io=True, + eps=1e-8): + super(GVP, self).__init__() + self.si, self.vi = in_dims + self.so, self.vo = out_dims + self.tuple_io = tuple_io + if self.vi: + self.h_dim = h_dim or max(self.vi, self.vo) + self.wh = Dense(self.vi, self.h_dim, has_bias=False) + self.ws = Dense(self.h_dim + self.si, self.so) + if self.vo: + self.wv = Dense(self.h_dim, self.vo, has_bias=False) + if vector_gate: + self.wg = Dense(self.so, self.vo) + else: + self.ws = Dense(self.si, self.so) + + self.vector_gate = vector_gate + self.scalar_act, self.vector_act = activations + self.eps = eps + + def construct(self, x): + """GVP construction""" + + if self.vi: + s, v = x + v = ms_transpose(v, (v.ndim - 1), (v.ndim - 2)) + vh = self.wh(v) + vn = _norm_no_nan(vh, axis=-2, eps=self.eps) + concat_op = ops.Concat(axis=-1) + s = self.ws(concat_op((s, vn))) + if self.scalar_act: + s = self.scalar_act(s) + if self.vo: + v = self.wv(vh) + v = ms_transpose(v, (v.ndim - 1), (v.ndim - 2)) + if self.vector_gate: + unsqueeze = ops.ExpandDims() + g = 
unsqueeze(self.wg(s), -1) + else: + g = _norm_no_nan(v, axis=-1, keepdims=True, eps=self.eps) + if self.vector_act: + g = self.vector_act(g) + v = v * g + else: + if self.tuple_io: + assert x[1] is None + x = x[0] + s = self.ws(x) + if self.scalar_act: + s = self.scalar_act(s) + if self.vo: + zeros = ops.Zeros() + v = zeros(list(s.shape)[:-1] + [self.vo, 3]) + + if self.vo: + return (s, v) + if self.tuple_io: + return (s, None) + return s + + +class _VDropout(nn.Cell): + """Dropout""" + + def __init__(self, drop_rate): + super(_VDropout, self).__init__() + self.drop_rate = drop_rate + self.dropout = nn.Dropout(drop_rate) + self.ones = ops.Ones() + self.unsqueeze = ops.ExpandDims() + + def construct(self, x): + """Dropout construction""" + + if x is None: + return None + if not self.training: + return x + a = self.ones(x.shape[:-1], x.dtype) + mask = self.dropout(a) + mask = self.unsqueeze(mask, -1) + x = mask * x / (1 - self.drop_rate) + return x + + +class Dropout(nn.Cell): + """Dropout""" + + def __init__(self, drop_rate): + super(Dropout, self).__init__() + self.sdropout = nn.Dropout(1 - drop_rate) + self.vdropout = _VDropout(1 - drop_rate) + + def construct(self, x): + if isinstance(x, ms.Tensor): + return self.sdropout(x) + s, v = x + return self.sdropout(s), self.vdropout(v) + + +class Dense(nn.Cell): + """ + preprocess input of each layer. + """ + + def __init__(self, + in_channels=None, + out_channels=None, + weight_init='normal', + bias_init='zeros', + has_bias=True, + activation=None): + super(Dense, self).__init__() + self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name) + self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name) + self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name) + self.reshape = ops.Reshape() + self.shape_op = ops.Shape() + + if isinstance(weight_init, Tensor): + if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ + weight_init.shape[1] != in_channels: + raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must " + f"be equal to 2, and the first dim must be equal to 'out_channels', and the " + f"second dim must be equal to 'in_channels'. But got 'weight_init': {weight_init}, " + f"'out_channels': {out_channels}, 'in_channels': {in_channels}.") + self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") + + self.bias = None + if self.has_bias: + if isinstance(bias_init, Tensor): + if bias_init.ndim != 1 or bias_init.shape[0] != out_channels: + raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must " + f"be equal to 1, and the first dim must be equal to 'out_channels'. 
But got " + f"'bias_init': {bias_init}, 'out_channels': {out_channels}.") + self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") + self.bias_add = ops.BiasAdd() + + self.matmul = ops.MatMul(transpose_b=True) + self.activation = get_activation(activation) if isinstance(activation, str) else activation + if activation is not None and not isinstance(self.activation, (nn.Cell, Primitive)): + raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got " + f"{type(activation).__name__}.") + self.activation_flag = self.activation is not None + + self.cast = ops.Cast() + self.get_dtype = ops.DType() + + def construct(self, x): + """Dense construction""" + x = self.cast(x, ms.float16) + + x_shape = self.shape_op(x) + if len(x_shape) != 2: + x = self.reshape(x, (-1, x_shape[-1])) + x = self.matmul(x, self.cast(self.weight, x.dtype)) + if self.has_bias: + x = self.bias_add(x, self.cast(self.bias, x.dtype)) + if self.activation_flag: + x = self.activation(x) + if len(x_shape) != 2: + out_shape = x_shape[:-1] + (-1,) + x = self.reshape(x, out_shape) + + x = self.cast(x, ms.float32) + return x + + +class LayerNorm(nn.Cell): + """Layer normalization""" + + def __init__(self, dims, tuple_io=True, eps=1e-8): + super(LayerNorm, self).__init__() + self.tuple_io = tuple_io + self.s, self.v = dims + self.scalar_norm = nn.LayerNorm([self.s]) + self.eps = eps + + def construct(self, x): + """Layer normalization construction""" + + if not self.v: + if self.tuple_io: + return self.scalar_norm(x[0]), None + return self.scalar_norm(x) + s, v = x + vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False, eps=self.eps) + nonzero_mask = (vn > 2 * self.eps) + vn = (vn * nonzero_mask) + nonzero_mask = ms.ops.Cast()(nonzero_mask, ms.float32) + v_1 = ops.ReduceSum(keep_dims=True)(vn, axis=-2) + v_2 = self.eps + ops.ReduceSum(keep_dims=True)(nonzero_mask, axis=-2) + vn = v_1 / v_2 + sqrt = ops.Sqrt() + vn = sqrt(vn + self.eps) + v = nonzero_mask * (v / vn) + return self.scalar_norm(s), v + + +class GVPConv(MessagePassing): + """GVP Convolution""" + + def __init__(self, in_dims, out_dims, edge_dims, n_layers=3, + vector_gate=False, module_list=None, aggr="mean", eps=1e-8, + activations=(ops.ReLU(), ops.Sigmoid())): + super(GVPConv, self).__init__() + self.eps = eps + self.si, self.vi = in_dims + self.so, self.vo = out_dims + self.se, self.ve = edge_dims + self.aggr = aggr + + module_list = module_list or [] + if not module_list: + if n_layers == 1: + module_list.append( + GVP((2 * self.si + self.se, 2 * self.vi + self.ve), + (self.so, self.vo), activations=(None, None))) + else: + module_list.append( + GVP((2 * self.si + self.se, 2 * self.vi + self.ve), out_dims, + vector_gate=vector_gate, activations=activations) + ) + for _ in range(n_layers - 2): + module_list.append(GVP(out_dims, out_dims, + vector_gate=vector_gate)) + module_list.append(GVP(out_dims, out_dims, + activations=(None, None))) + self.message_func = nn.SequentialCell(*module_list) + + def construct(self, x, edge_index, edge_attr): + x_s, x_v = x + message = self.propagate(x_s, edge_index, s=x_s, v=x_v.reshape(x_v.shape[0], 3 * x_v.shape[1]), + edge_attr=edge_attr, aggr=self.aggr) + output = _split(message, self.vo) + return output + + def message(self, s_i, v_i, s_j, v_j, edge_attr): + v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3) + v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3) + message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)) + message = self.message_func(message) + output = 
_merge(*message) + return output + + +class GVPConvLayer(nn.Cell): + """GVP Convolution layer""" + + def __init__(self, node_dims, edge_dims, vector_gate=False, + n_message=3, n_feedforward=2, drop_rate=.1, + autoregressive=False, attention_heads=0, + conv_activations=(ops.ReLU(), ops.Sigmoid()), + n_edge_gvps=0, layernorm=True, eps=1e-8): + + super(GVPConvLayer, self).__init__() + if attention_heads == 0: + self.conv = GVPConv( + node_dims, node_dims, edge_dims, n_layers=n_message, + vector_gate=vector_gate, + aggr="add" if autoregressive else "mean", + activations=conv_activations, + eps=eps, + ) + else: + raise NotImplementedError + if layernorm: + self.norm = nn.CellList([LayerNorm(node_dims, eps=eps) for _ in range(2)]) + else: + self.norm = nn.CellList([nn.Identity() for _ in range(2)]) + self.dropout = nn.CellList([Dropout(drop_rate) for _ in range(2)]) + + ff_func = [] + if n_feedforward == 1: + ff_func.append(GVP(node_dims, node_dims, activations=(None, None))) + else: + hid_dims = 4 * node_dims[0], 2 * node_dims[1] + ff_func.append(GVP(node_dims, hid_dims, vector_gate=vector_gate)) + for _ in range(n_feedforward - 2): + ff_func.append(GVP(hid_dims, hid_dims, vector_gate=vector_gate)) + ff_func.append(GVP(hid_dims, node_dims, activations=(None, None))) + self.ff_func = nn.SequentialCell(*ff_func) + + self.edge_message_func = None + if n_edge_gvps > 0: + si, vi = node_dims + se, ve = edge_dims + module_list = [ + GVP((2 * si + se, 2 * vi + ve), edge_dims, vector_gate=vector_gate) + ] + for _ in range(n_edge_gvps - 2): + module_list.append(GVP(edge_dims, edge_dims, + vector_gate=vector_gate)) + if n_edge_gvps > 1: + module_list.append(GVP(edge_dims, edge_dims, + activations=(None, None))) + self.edge_message_func = nn.SequentialCell(*module_list) + if layernorm: + self.edge_norm = LayerNorm(edge_dims, eps=eps) + else: + self.edge_norm = nn.Identity() + self.edge_dropout = Dropout(drop_rate) + + def construct(self, x, edge_index, edge_attr, + autoregressive_x=None, node_mask=None): + """GVP Convolution layer construction""" + + if self.edge_message_func: + src, dst = edge_index + if autoregressive_x is None: + x_src = x[0][src], x[1][src] + else: + unsqueeze = ops.ExpandDims() + mask = (src < dst) + mask = unsqueeze(mask, -1) + x_src = ( + ms.numpy.where(mask, x[0][src], autoregressive_x[0][src]), + ms.numpy.where(unsqueeze(mask, -1), x[1][src], + autoregressive_x[1][src]) + ) + x_dst = x[0][dst], x[1][dst] + + x_edge = ( + ops.Concat(axis=-1)([x_src[0], edge_attr[0], x_dst[0]]), + ops.Concat(axis=-2)([x_src[1], edge_attr[1], x_dst[1]]) + ) + edge_attr_dh = self.edge_message_func(x_edge) + edge_attr = self.edge_norm(tuple_sum(edge_attr, + self.edge_dropout(edge_attr_dh))) + + if autoregressive_x is not None: + src, dst = edge_index + mask = src < dst + edge_index_forward = edge_index[:, mask] + edge_index_backward = edge_index[:, ~mask] + edge_attr_forward = tuple_index(edge_attr, mask) + edge_attr_backward = tuple_index(edge_attr, ~mask) + + dh = tuple_sum( + self.conv(x, edge_index_forward, edge_attr_forward), + self.conv(autoregressive_x, edge_index_backward, edge_attr_backward) + ) + unsqueeze = ops.ExpandDims() + + src = ops.OnesLike()(dst) + index = ms.Tensor(dst, ms.int32) + count = scatter_sum(src, index, dim_size=dh[0].shape[0]) + + min_value = ms.Tensor(1, ms.float32) + count = ops.clip_by_value(count, clip_value_min=min_value) + count = unsqueeze(count, -1) + + dh = dh[0] / count, unsqueeze((dh[1] / count), -1) + else: + dh = self.conv(x, edge_index, edge_attr) + + if 
node_mask is not None: + x_ = x + x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask) + + x = self.norm[0](tuple_sum(x, self.dropout[0](dh))) + + dh = self.ff_func(x) + + x = self.norm[1](tuple_sum(x, self.dropout[1](dh))) + + if node_mask is not None: + x_[0][node_mask], x_[1][node_mask] = x[0], x[1] + x = x_ + + return x, edge_attr + + +class SinusoidalPositionalEmbedding(nn.Cell): + """Sinusoidal positional embedding""" + + def __init__(self, embed_dim, padding_idx): + super().__init__() + self.embed_dim = embed_dim + self.padding_idx = padding_idx + self._float_tensor = ms.Tensor(1, ms.float32) + self.weights = None + + def construct(self, x): + """Sinusoidal positional embedding construction""" + + bsz, seq_len = x.shape + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.shape[0]: + self.weights = self.get_embedding(max_pos) + self.weights = self.weights.astype(self._float_tensor.dtype) + + positions = self.make_positions(x) + positions = ops.Cast()(positions, ms.int32) + output = ops.gather(self.weights, positions.view((-1)), 0).view((bsz, seq_len, -1)) + return ops.stop_gradient(output) + + + def make_positions(self, x): + mask = ops.NotEqual()(x, self.padding_idx) + range_buf = ms.numpy.arange(x.shape[1]).expand_as(x) + self.padding_idx + 1 + positions = range_buf.expand_as(x) + floor = ops.Floor() + mask = ops.Cast()(mask, ms.float32) + return positions * floor(mask) + self.padding_idx * (1 - floor(mask)) + + def get_embedding(self, num_embeddings): + """Get sinusoidal positional embedding""" + + half_dim = self.embed_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = ops.Exp()(ms.numpy.arange(half_dim, dtype=ms.float32) * -emb) + unsqueeze = ops.ExpandDims() + emb = unsqueeze(ms.numpy.arange(num_embeddings, dtype=ms.float32), 1) * unsqueeze(emb, 0) + concat = ops.Concat(1) + emb = concat([ops.Sin()(emb), ops.Cos()(emb)]).view((num_embeddings, -1)) + if self.embed_dim % 2 == 1: + # zero pad + emb = concat([emb, ops.Zeros()((num_embeddings, 1), ms.float32)]) + if self.padding_idx is not None: + emb[self.padding_idx, :] = 0 + return emb + + +class FairseqIncrementalState: + """Fair sequence incremental state""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + def _get_full_incremental_state_key(self, key: str) -> str: + return "{}.{}".format(self._incremental_state_id, key) + + +def with_incremental_state(cls): + """Incremental state""" + cls.__bases__ = (FairseqIncrementalState,) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState + 
) + return cls + + +@with_incremental_state +class MultiheadAttention(nn.Cell): + """Multihead attention""" + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = Dense(self.kdim, embed_dim, has_bias=bias) + self.v_proj = Dense(self.vdim, embed_dim, has_bias=bias) + self.q_proj = Dense(embed_dim, embed_dim, has_bias=bias) + + self.out_proj = Dense(embed_dim, embed_dim, has_bias=bias) + + if add_bias_kv: + self.bias_k = ms.Parameter(ms.Tensor((1, 1, embed_dim))) + self.bias_v = ms.Parameter(ms.Tensor((1, 1, embed_dim))) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.onnx_trace = False + + self.enable_torch_version = True + + @staticmethod + def apply_sparse_mask(attn_weights): + return attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[ms.Tensor], + prev_key_padding_mask: Optional[ms.Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[ms.Tensor]: + """Append key padding masks""" + + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + prev_key_padding_mask = ops.Cast()(prev_key_padding_mask, ms.int32) + key_padding_mask = ops.Cast()(key_padding_mask, ms.int32) + new_key_padding_mask = ops.Concat(1)( + [prev_key_padding_mask, key_padding_mask] + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + filler = ops.Zeros()( + (batch_size, src_len - prev_key_padding_mask.shape[1]), prev_key_padding_mask.dtype + ) + prev_key_padding_mask = ops.Cast()(prev_key_padding_mask, ms.int32) + filler = ops.Cast()(filler, ms.int32) + new_key_padding_mask = ops.Concat(1)( + [prev_key_padding_mask, filler] + ) + elif key_padding_mask is not None: + filler = ops.Zeros()( + (batch_size, src_len - key_padding_mask.shape[1]), + ms.int32, + ) + + key_padding_mask = ops.Cast()(key_padding_mask, ms.int32) + if filler.shape == (1, 0): + new_key_padding_mask = key_padding_mask + else: + new_key_padding_mask = ops.concat((filler, key_padding_mask), 1) + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + """Reset parameters""" + + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + self.k_proj.weight = initializer(XavierUniform(gain=1 / math.sqrt(2)), + 
self.k_proj.weight.shape, self.k_proj.weight.dtype) + self.v_proj.weight = initializer(XavierUniform(gain=1 / math.sqrt(2)), + self.v_proj.weight.shape, self.v_proj.weight.dtype) + self.q_proj.weight = initializer(XavierUniform(gain=1 / math.sqrt(2)), + self.q_proj.weight.shape, self.q_proj.weight.dtype) + else: + self.k_proj.weight = initializer(XavierUniform(), self.k_proj.weight.shape, + self.k_proj.weight.dtype) + self.v_proj.weight = initializer(XavierUniform(), self.v_proj.weight.shape, + self.v_proj.weight.dtype) + self.q_proj.weight = initializer(XavierUniform(), self.q_proj.weight.shape, + self.q_proj.weight.dtype) + + self.out_proj.weight = initializer(XavierUniform(), self.out_proj.weight.shape, + self.out_proj.weight.dtype) + if self.out_proj.bias is not None: + ms.common.initializer.Constant(value=0.0)(self.out_proj.bias) + if self.bias_k is not None: + self.bias_k = initializer(XavierNormal(), self.bias_k.shape, + self.bias_k.dtype) + if self.bias_v is not None: + self.bias_v = initializer(XavierNormal(), self.bias_v.shape, + self.bias_v.dtype) + + def construct( + self, + query, + key: Optional[ms.Tensor], + value: Optional[ms.Tensor], + key_padding_mask: Optional[ms.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[ms.Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[ms.Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[ms.Tensor, Optional[ms.Tensor]]: + """Multihead attention construction""" + + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.shape + assert embed_dim == self.embed_dim + assert list(query.shape) == [tgt_len, bsz, embed_dim] + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = ops.Concat()([k, ms.numpy.tile(self.bias_k, (1, bsz, 1))]) + v = ops.Concat()([v, ms.numpy.tile(self.bias_v, (1, bsz, 1))]) + if attn_mask is not None: + attn_mask_zero = ops.Zeros()((attn_mask.shape[0], 1), attn_mask.dtype) + attn_mask = ops.Concat(1)( + [attn_mask, attn_mask_zero] + ) + if key_padding_mask is not None: + key_padding_mask_zero = ops.Zeros()((key_padding_mask.shape[0], 1), key_padding_mask.dtype) + key_padding_mask = ops.Concat(1)( + [ + key_padding_mask, + key_padding_mask_zero + ] + ) + + q = ms_transpose(q.view((tgt_len, bsz * self.num_heads, self.head_dim)), 0, 1) + if k is not None: + k = ms_transpose(k.view((-1, bsz * self.num_heads, self.head_dim)), 0, 1) + if v is not None: + v = ms_transpose(v.view((-1, bsz * self.num_heads, self.head_dim)), 0, 1) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in 
saved_state: + o_prev_key = saved_state.get("prev_key", " ") + assert o_prev_key is not None + prev_key = o_prev_key.view((bsz * self.num_heads, -1, self.head_dim)) + if static_kv: + k = prev_key + else: + assert k is not None + k = ops.Concat(1)([prev_key, k]) + if "prev_value" in saved_state: + o_prev_value = saved_state.get("prev_value", " ") + assert o_prev_value is not None + prev_value = o_prev_value.view((bsz * self.num_heads, -1, self.head_dim)) + if static_kv: + v = prev_value + else: + assert v is not None + v = ops.Concat(1)([prev_value, v]) + prev_key_padding_mask: Optional[ms.Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state.get("prev_key_padding_mask", " ") + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.shape[1], + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view((bsz, self.num_heads, -1, self.head_dim)) + saved_state["prev_value"] = v.view((bsz, self.num_heads, -1, self.head_dim)) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + src_len = k.shape[1] + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == bsz + assert key_padding_mask.shape[1] == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = ops.Concat(1)([k, ops.Zeros()(((k.shape[0], 1) + k.shape[2:]), k.dtype)]) + k = ops.Concat(1)([k, ops.Zeros()(((k.shape[0], 1) + k.shape[2:]), k.dtype)]) + v = ops.Concat(1)([v, ops.Zeros()(((v.shape[0], 1) + v.shape[2:]), v.dtype)]) + if attn_mask is not None: + attn_mask = ops.Concat(1)( + [attn_mask, ops.Zeros()((attn_mask.shape[0], 1), attn_mask.dtype)]) + if key_padding_mask is not None: + key_padding_mask = ops.Concat(1)( + [ + key_padding_mask, + ops.Zeros()((key_padding_mask.shape[0], 1), key_padding_mask.dtype), + ]) + + q = ops.Cast()(q, ms.float16) + k = ops.Cast()(ms_transpose(k, 1, 2), ms.float16) + attn_weights = ops.BatchMatMul()(q, k) + attn_weights = ops.Cast()(attn_weights, ms.float32) + + attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights) + + assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + unsqueeze = ops.ExpandDims() + if attn_mask is not None: + attn_mask = unsqueeze(attn_mask, 0) + if self.onnx_trace: + attn_mask = ms.numpy.tile(attn_mask, (attn_weights.shape[0], 1, 1)) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view((bsz, self.num_heads, tgt_len, src_len)) + key_padding_mask_ = unsqueeze(unsqueeze(key_padding_mask, 1), 2).astype(ms.bool_) + attn_weights = ops.MaskedFill()(attn_weights, key_padding_mask_, + ms.Tensor(-1e9, ms.float32)) + attn_weights = attn_weights.view((bsz * self.num_heads, tgt_len, src_len)) + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace) + attn_weights = attn_weights_float.astype(attn_weights.dtype) + + dropout_net = nn.Dropout(keep_prob=(1 - 
self.dropout)) + if self.training: + dropout_net.set_train() + attn_probs = dropout_net(attn_weights_float.astype(attn_weights.dtype)) + + assert v is not None + + attn_probs = ops.Cast()(attn_probs, ms.float16) + v = ops.Cast()(v, ms.float16) + attn = ops.BatchMatMul()(attn_probs, v) + attn = ops.Cast()(attn, ms.float32) + + assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.shape[1] == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.view((tgt_len, bsz, embed_dim)) + else: + attn = ms_transpose(attn, 0, 1).view((tgt_len, bsz, embed_dim)) + attn = self.out_proj(attn) + attn_weights: Optional[ms.Tensor] = None + if need_weights: + attn_weights = ms_transpose(attn_weights_float.view(( + bsz, self.num_heads, tgt_len, src_len + )), 1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(axis=0) + + return attn, attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade state dict name""" + + prefix = name + "." if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[ms.Tensor]]]] + ) -> Dict[str, Optional[ms.Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + empty_result: Dict[str, Optional[ms.Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[ms.Tensor]]], + buffer: Dict[str, Optional[ms.Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + +def _set_input_buffer(selfattention: MultiheadAttention, + incremental_state: Dict[str, Dict[str, Optional[ms.Tensor]]], + buffer: Dict[str, Optional[ms.Tensor]], + ): + """Set input buffer""" + return selfattention.set_incremental_state(incremental_state, "attn_state", buffer) + + +def _get_input_buffer( + selfattention: MultiheadAttention, incremental_state: Optional[Dict[str, Dict[str, Optional[ms.Tensor]]]] +) -> Dict[str, Optional[ms.Tensor]]: + """Get input buffer""" + result = selfattention.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + empty_result: Dict[str, Optional[ms.Tensor]] = {} + return empty_result diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/esm_wrapcell.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/esm_wrapcell.py 
new file mode 100644 index 0000000000000000000000000000000000000000..90e355c84c30ec3ccbf4d168bfa73728d49fb232 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/esm_wrapcell.py @@ -0,0 +1,41 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""TrainOneStepCell""" +import mindspore as ms +import mindspore.nn as nn + + +class TrainOneStepCell(nn.Cell): + """training""" + def __init__(self, network, optimizer): + super(TrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.optimizer = optimizer + self.weights = self.optimizer.parameters + self.grad = ms.ops.GradOperation(get_by_list=True) + + def construct(self, inputs): + """Train net construction""" + loss = self.network((inputs['coords'], inputs['padding_mask'], inputs['confidence'], + inputs['prev_output_tokens']), label=inputs['target']) + grads = \ + self.grad(self.network, self.weights)((inputs['coords'], + inputs['padding_mask'], + inputs['confidence'], + inputs['prev_output_tokens']), + inputs['target']) + self.optimizer(grads) + return loss diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/features.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/features.py new file mode 100644 index 0000000000000000000000000000000000000000..c97bc098ea90371fd6e6ddd080da71b950e1444e --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/features.py @@ -0,0 +1,366 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Feature extraction""" + +import math +import numpy as np +import mindspore as ms +import mindspore.ops as ops +import mindspore.nn as nn +# pylint: disable=relative-beyond-top-level +from .basic_modules import GVP, LayerNorm, Dense +from .util import normalize, norm, nan_to_num, rbf, flatten_graph, ms_transpose, ms_padding_without_val + + +class GVPInputFeaturizer(nn.Cell): + """Input feature extraction for GVP""" + + @staticmethod + def get_node_features(coords, coord_mask, with_coord_mask=True): + """Get node features""" + node_scalar_features = GVPInputFeaturizer._dihedrals(coords) + if with_coord_mask: + coord_mask = ops.ExpandDims()(ops.Cast()(coord_mask, ms.float32), -1) + node_scalar_features = ops.Concat(axis=-1)([node_scalar_features, coord_mask]) + x_ca = coords[:, :, 1] + orientations = GVPInputFeaturizer._orientations(x_ca) + sidechains = GVPInputFeaturizer._sidechains(coords) + node_vector_features = ops.Concat(axis=-2)([orientations, ops.ExpandDims()(sidechains, -2)]) + return node_scalar_features, node_vector_features + + @staticmethod + def _orientations(x): + + forward = normalize(x[:, 1:] - x[:, :-1]) + backward = normalize(x[:, :-1] - x[:, 1:]) + forward = ops.concat((forward, ops.Zeros()((forward.shape[0], 1, forward.shape[2]), ms.float32)), 1) + backward = ops.concat((ops.Zeros()((backward.shape[0], 1, backward.shape[2]), ms.float32), backward), 1) + + output = ops.Concat(axis=-2)([ops.ExpandDims()(forward, -2), ops.ExpandDims()(backward, -2)]) + return output + + @staticmethod + def _sidechains(x): + n, origin, c = x[:, :, 0], x[:, :, 1], x[:, :, 2] + c, n = normalize(c - origin), normalize(n - origin) + bisector = normalize(c + n) + perp = normalize(ms.numpy.cross(c, n)) + vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3) + return vec + + @staticmethod + def _dihedrals(x, eps=1e-7): + """Dihedron""" + + y = x[:, :, :3].reshape((x.shape[0], (x.shape[1] * x.shape[2]), x.shape[3])) + bsz = x.shape[0] + dx = y[:, 1:] - y[:, :-1] + u = normalize(dx, dim=-1) + u_2 = u[:, :-2] + u_1 = u[:, 1:-1] + u_0 = u[:, 2:] + + # Backbone normals + n_2 = normalize(ms.numpy.cross(u_2, u_1), dim=-1) + n_1 = normalize(ms.numpy.cross(u_1, u_0), dim=-1) + + # Angle between normals + cosd = ops.ReduceSum()(n_2 * n_1, -1) + + min_value = ms.Tensor((-1 + eps), ms.float32) + max_value = ms.Tensor((1 - eps), ms.float32) + cosd = ops.clip_by_value(cosd, clip_value_min=min_value, clip_value_max=max_value) + d = ops.Sign()((u_2 * n_1).sum(-1)) * ops.ACos()(cosd) + + # This scheme will remove phi[0], psi[-1], omega[-1] + d = ms_padding_without_val(d, [1, 2]) + d = ops.Reshape()(d, (bsz, -1, 3)) + # Lift angle representations to the circle + d_features = ops.Concat(axis=-1)([ops.Cos()(d), ops.Sin()(d)]) + return d_features + + @staticmethod + def _positional_embeddings(edge_index, + num_embeddings=None, + num_positional_embeddings=16): + """Positional embeddings""" + + num_embeddings = num_embeddings or num_positional_embeddings or [] + d = edge_index[0] - edge_index[1] + + frequency = ops.Exp()( + ms.numpy.arange(0, num_embeddings, 2, dtype=ms.float32) + * -(np.log(10000.0) / num_embeddings) + ) + angles = ops.ExpandDims()(d, -1) * frequency + e = ops.Concat(-1)((ops.Cos()(angles), ops.Sin()(angles))) + return e + + @staticmethod + def _dist(x, coord_mask, padding_mask, top_k_neighbors): + """ Pairwise euclidean distances """ + bsz, maxlen = x.shape[0], x.shape[1] + coord_mask = ops.Cast()(coord_mask, 
ms.float32) + coord_mask_2d = ops.ExpandDims()(coord_mask, 1) * ops.ExpandDims()(coord_mask, 2) + residue_mask = ~padding_mask + residue_mask = ops.Cast()(residue_mask, ms.float32) + residue_mask_2d = ops.ExpandDims()(residue_mask, 1) * ops.ExpandDims()(residue_mask, 2) + dx = ops.ExpandDims()(x, 1) - ops.ExpandDims()(x, 2) + d = coord_mask_2d * norm(dx, dim=-1) + + # sorting preference: first those with coords, then among the residues that + # exist but are masked use distance in sequence as tie breaker, and then the + # residues that came from padding are last + seqpos = ms.numpy.arange(maxlen) + seqpos_1 = ops.ExpandDims()(seqpos, 1) + seqpos_0 = ops.ExpandDims()(seqpos, 0) + d_seq = ops.Abs()(seqpos_1 - seqpos_0) + if bsz != 1: + d_seq = ms.numpy.tile(d_seq, (bsz, 1, 1)) + coord_mask_2d = ops.Cast()(coord_mask_2d, ms.bool_) + residue_mask_2d = ops.Cast()(residue_mask_2d, ms.bool_) + verse_coord_mask_2d = ops.Cast()(~coord_mask_2d, ms.float32) + verse_residue_mask_2d = ops.Cast()(~residue_mask_2d, ms.float32) + d_adjust = nan_to_num(d) + (verse_coord_mask_2d) * (1e8 + d_seq * 1e6) + ( + verse_residue_mask_2d) * (1e10) + + if top_k_neighbors == -1: + d_neighbors = d_adjust / 1e4 + e_idx = seqpos.repeat( + *d_neighbors.shape[:-1], 1) + else: + d_adjust = d_adjust / 1e4 + d_neighbors, e_idx = ops.TopK(sorted=True)(d_adjust, d_adjust.shape[-1]) + d_neighbors, e_idx = d_neighbors[..., ::-1], e_idx[..., ::-1] + d_neighbors, e_idx = d_neighbors[:, :, 0:int(min(top_k_neighbors, x.shape[1]))], \ + e_idx[:, :, 0:int(min(top_k_neighbors, x.shape[1]))] + d_neighbors = ms.Tensor(d_neighbors, ms.float32)*1e4 + coord_mask_neighbors = (d_neighbors < 5e7) + residue_mask_neighbors = (d_neighbors < 5e9) + output = [d_neighbors, e_idx, coord_mask_neighbors, residue_mask_neighbors] + return output + + +class Normalize(nn.Cell): + """Normalization""" + + def __init__(self, features, epsilon=1e-6): + super(Normalize, self).__init__() + self.gain = ms.Parameter(ops.Ones()(features, ms.float32)) + self.bias = ms.Parameter(ops.Zeros()(features, ms.float32)) + self.epsilon = epsilon + + def construct(self, x, dim=-1): + """Normalization construction""" + + mu = x.mean(dim, keep_dims=True) + sigma = ops.Sqrt()(x.var(dim, keepdims=True) + self.epsilon) + gain = self.gain + bias = self.bias + # Reshape + if dim != -1: + shape = [1] * len(mu.size()) + shape[dim] = self.gain.size()[0] + gain = gain.view(shape) + bias = bias.view(shape) + return gain * (x - mu) / (sigma + self.epsilon) + bias + + +class DihedralFeatures(nn.Cell): + """Dihedral features""" + + def __init__(self, node_embed_dim): + """ Embed dihedral angle features. 
""" + super(DihedralFeatures, self).__init__() + # 3 dihedral angles; sin and cos of each angle + node_in = 6 + # Normalization and embedding + self.node_embedding = Dense(node_in, node_embed_dim, has_bias=True) + self.norm_nodes = Normalize(node_embed_dim) + + @staticmethod + def _dihedrals(x, eps=1e-7, return_angles=False): + """Dihedron in DihedralFeatures""" + + # First 3 coordinates are N, CA, C + x = x[:, :, :3, :].reshape(x.shape[0], 3 * x.shape[1], 3) + + # Shifted slices of unit vectors + dx = x[:, 1:, :] - x[:, :-1, :] + u = ops.L2Normalize(axis=-1)(dx) + u_2 = u[:, :-2, :] + u_1 = u[:, 1:-1, :] + u_0 = u[:, 2:, :] + # Backbone normals + n_2 = ops.L2Normalize(axis=-1)(ms.numpy.cross(u_2, u_1)) + n_1 = ops.L2Normalize(axis=-1)(ms.numpy.cross(u_1, u_0)) + + # Angle between normals + cosd = (n_2 * n_1).sum(-1) + min_value = ms.Tensor((-1 + eps), ms.float32) + max_value = ms.Tensor((1 - eps), ms.float32) + cosd = ops.clip_by_value(cosd, clip_value_min=min_value, clip_value_max=max_value) + d = ops.Sign()((u_2 * n_1).sum(-1)) * ops.ACos()(cosd) + + # This scheme will remove phi[0], psi[-1], omega[-1] + d = ms_padding_without_val(d, [1, 2]) + d = d.view((d.shape[0], int(d.shape[1] / 3), 3)) + phi, psi, omega = ops.Unstack(axis=-1)(d) + + if return_angles: + return phi, psi, omega + + # Lift angle representations to the circle + d_features = ops.Concat(axis=2)((ops.Cos()(d), ops.Sin()(d))) + return d_features + + def construct(self, x): + """ Featurize coordinates as an attributed graph """ + v = self._dihedrals(x) + v = self.node_embedding(v) + v = self.norm_nodes(v) + return v + + +class GVPGraphEmbedding(GVPInputFeaturizer): + """GVP graph embedding""" + + def __init__(self, args): + super().__init__() + self.top_k_neighbors = args.top_k_neighbors + self.num_positional_embeddings = 16 + self.remove_edges_without_coords = True + node_input_dim = (7, 3) + edge_input_dim = (34, 1) + node_hidden_dim = (args.node_hidden_dim_scalar, + args.node_hidden_dim_vector) + edge_hidden_dim = (args.edge_hidden_dim_scalar, + args.edge_hidden_dim_vector) + self.embed_node = nn.SequentialCell( + [GVP(node_input_dim, node_hidden_dim, activations=(None, None)), + LayerNorm(node_hidden_dim, eps=1e-4)] + ) + self.embed_edge = nn.SequentialCell( + [GVP(edge_input_dim, edge_hidden_dim, activations=(None, None)), + LayerNorm(edge_hidden_dim, eps=1e-4)] + ) + self.embed_confidence = Dense(16, args.node_hidden_dim_scalar) + + def construct(self, coords, coord_mask, padding_mask, confidence): + """GVP graph embedding construction""" + + node_features = self.get_node_features(coords, coord_mask) + + edge_features, edge_index = self.get_edge_features( + coords, coord_mask, padding_mask) + node_embeddings_scalar, node_embeddings_vector = self.embed_node(node_features) + edge_embeddings = self.embed_edge(edge_features) + + rbf_rep = rbf(confidence, 0., 1.) 
+ + node_embeddings = ( + node_embeddings_scalar + self.embed_confidence(rbf_rep), + node_embeddings_vector + ) + + + node_embeddings, edge_embeddings, edge_index = flatten_graph( + node_embeddings, edge_embeddings, edge_index) + return node_embeddings, edge_embeddings, edge_index + + def get_edge_features(self, coords, coord_mask, padding_mask): + """Get edge features""" + + x_ca = coords[:, :, 1] + + # Get distances to the top k neighbors + e_dist, e_idx, e_coord_mask, e_residue_mask = GVPInputFeaturizer._dist( + x_ca, coord_mask, padding_mask, self.top_k_neighbors) + # Flatten the graph to be batch size 1 for torch_geometric package + dest = e_idx + e_idx_b, e_idx_l, k = e_idx.shape[:3] + + src = ms.numpy.arange(e_idx_l).view((1, e_idx_l, 1)) + src = ops.BroadcastTo((e_idx_b, e_idx_l, k))(src) + + + edge_index = ops.Stack(axis=0)([src, dest]) + + edge_index = edge_index.reshape((edge_index.shape[0], edge_index.shape[1], + (edge_index.shape[2] * edge_index.shape[3]))) + + # After flattening, [B, E] + e_dist = e_dist.reshape((e_dist.shape[0], (e_dist.shape[1] * e_dist.shape[2]))) + + e_coord_mask = e_coord_mask.reshape((e_coord_mask.shape[0], (e_coord_mask.shape[1] * e_coord_mask.shape[2]))) + e_coord_mask = ops.ExpandDims()(e_coord_mask, -1) + e_residue_mask = e_residue_mask.reshape((e_residue_mask.shape[0], + (e_residue_mask.shape[1] * e_residue_mask.shape[2]))) + + # Calculate relative positional embeddings and distance RBF + pos_embeddings = GVPInputFeaturizer._positional_embeddings( + edge_index, + num_positional_embeddings=self.num_positional_embeddings, + ) + d_rbf = rbf(e_dist, 0., 20.) + + # Calculate relative orientation + x_src = ops.ExpandDims()(x_ca, 2) + x_src = ops.BroadcastTo((-1, -1, k, -1))(x_src) + x_src = x_src.reshape((x_src.shape[0], (x_src.shape[1] * x_src.shape[2]), x_src.shape[3])) + + a = ops.ExpandDims()(edge_index[1, :, :], -1) + a = ops.BroadcastTo((e_idx_b, e_idx_l * k, 3))(a) + x_dest = ops.GatherD()( + x_ca, + 1, + a + ) + coord_mask_src = ops.ExpandDims()(coord_mask, 2) + coord_mask_src = ops.BroadcastTo((-1, -1, k))(coord_mask_src) + coord_mask_src = coord_mask_src.reshape((coord_mask_src.shape[0], + (coord_mask_src.shape[1] * coord_mask_src.shape[2]))) + + b = ops.BroadcastTo((e_idx_b, e_idx_l * k))(edge_index[1, :, :]) + + coord_mask_dest = ops.GatherD()( + coord_mask, + 1, + b + ) + e_vectors = x_src - x_dest + # For the ones without coordinates, substitute in the average vector + e_coord_mask = ops.Cast()(e_coord_mask, ms.float32) + e_vector_mean = ops.ReduceSum(keep_dims=True) \ + (e_vectors * e_coord_mask, axis=1) / ops.ReduceSum(keep_dims=True)(e_coord_mask, axis=1) + e_coord_mask = ops.Cast()(e_coord_mask, ms.bool_) + e_vectors = e_vectors * e_coord_mask + e_vector_mean * ~(e_coord_mask) + # Normalize and remove nans + edge_s = ops.Concat(axis=-1)([d_rbf, pos_embeddings]) + edge_v = ops.ExpandDims()(normalize(e_vectors), -2) + edge_s, edge_v = map(nan_to_num, (edge_s, edge_v)) + # Also add indications of whether the coordinates are present + + edge_s = ops.Concat(axis=-1)([ + edge_s, + ops.ExpandDims()((~coord_mask_src).astype(np.float32), -1), + ops.ExpandDims()((~coord_mask_dest).astype(np.float32), -1)]) + e_residue_mask = ops.Cast()(e_residue_mask, ms.bool_) + edge_index = edge_index.masked_fill(~e_residue_mask, -1) + + if self.remove_edges_without_coords: + edge_index = ops.masked_fill(edge_index, ~e_coord_mask.squeeze(-1), -1) + + return (edge_s, edge_v), ms_transpose(edge_index, 0, 1) diff --git 
a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/message_passing.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/message_passing.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaec8bd4c1fcb3103a2ecc519f00b75504268d4d
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/message_passing.py
@@ -0,0 +1,391 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Message passing"""
+import re
+import inspect
+from inspect import Parameter
+from typing import List, Optional, Any, Callable, Dict, Set, Tuple
+from collections import OrderedDict
+import pyparsing as pp
+from mindspore import Tensor
+import mindspore as ms
+from mindspore import ops
+import mindspore.nn as nn
+import mindspore.numpy as mnp
+from mindspore.ops import Size
+
+
+def param_type_repr(param) -> str:
+    """Return parameter type"""
+    if param.annotation is inspect.Parameter.empty:
+        return 'Tensor'
+    return sanitize(re.split(r':|='.strip(), str(param))[1])
+
+
+def split_types_repr(types_repr: str) -> List[str]:
+    """Split type"""
+    out = []
+    i = depth = 0
+    for j, char in enumerate(types_repr):
+        if char == '[':
+            depth += 1
+        elif char == ']':
+            depth -= 1
+        elif char == ',' and depth == 0:
+            out.append(types_repr[i:j].strip())
+            i = j + 1
+    out.append(types_repr[i:].strip())
+    return out
+
+
+def return_type_repr(signature) -> str:
+    """Return type"""
+    return_type = signature.return_annotation
+    if return_type is inspect.Parameter.empty:
+        return 'torch.Tensor'
+    if str(return_type)[:6] != '<class':
+        return sanitize(str(return_type))
+    return sanitize(return_type.__name__)
+
+
+def parse_types(func: Callable) -> List[Tuple[Dict[str, str], str]]:
+    """Return parse type"""
+    source = inspect.getsource(func)
+    signature = inspect.signature(func)
+
+    iterator = re.finditer(r'#\s*type:\s*\((.*)\)\s*->\s*(.*)\s*\n', source)
+    matches = list(iterator)
+
+    if matches:
+        out = []
+        args = list(signature.parameters.keys())
+        for match in matches:
+            arg_types_repr, return_type = match.groups()
+            arg_types = split_types_repr(arg_types_repr)
+            arg_types = OrderedDict((k, v) for k, v in zip(args, arg_types))
+            return_type = return_type.split('#')[0].strip()
+            out.append((arg_types, return_type))
+        return out
+
+    # Alternatively, parse annotations using the inspected signature.
+    ps = signature.parameters
+    arg_types = OrderedDict((k, param_type_repr(v)) for k, v in ps.items())
+    return [(arg_types, return_type_repr(signature))]
+
+
+def sanitize(type_repr: str):
+    """Sanitize"""
+    type_repr = re.sub(r'<class \'(.*)\'>', r'\1', type_repr)
+    type_repr = type_repr.replace('typing.', '')
+    type_repr = type_repr.replace('torch_sparse.tensor.', '')
+    type_repr = type_repr.replace('Adj', 'Union[Tensor, SparseTensor]')
+
+    # Replace `Union[..., NoneType]` by `Optional[...]`.
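+    # e.g. 'Union[Tensor, NoneType]' is rewritten to 'Optional[Tensor]' by the
+    # tree walk below (an illustrative input/output pair, not from the source).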
+ sexp = pp.nestedExpr(opener='[', closer=']') + tree = sexp.parseString(f'[{type_repr.replace(",", " ")}]').asList()[0] + + def union_to_optional_(tree): + for i, _ in enumerate(tree): + e, n = tree[i], tree[i + 1] if i + 1 < len(tree) else [] + if e == 'Union' and n[-1] == 'NoneType': + tree[i] = 'Optional' + tree[i + 1] = tree[i + 1][:-1] + elif e == 'Union' and 'NoneType' in n: + idx = n.index('NoneType') + n[idx] = [n[idx - 1]] + n[idx - 1] = 'Optional' + elif isinstance(e, list): + tree[i] = union_to_optional_(e) + return tree + + tree = union_to_optional_(tree) + type_repr = re.sub(r'\'|\"', '', str(tree)[1:-1]).replace(', [', '[') + + return type_repr + + +class Inspector: + """Inspector""" + + def __init__(self, base_class: Any): + self.base_class: Any = base_class + self.params: Dict[str, Dict[str, Any]] = {} + + def __implements__(self, cls, func_name: str) -> bool: + if cls.__name__ == 'MessagePassing': + return False + if func_name in cls.__dict__.keys(): + return True + return any(self.__implements__(c, func_name) for c in cls.__bases__) + + def inspect(self, func: Callable, + pop_first: bool = False) -> Dict[str, Any]: + params = inspect.signature(func).parameters + params = OrderedDict(params) + if pop_first: + params.popitem(last=False) + self.params[func.__name__] = params + + def keys(self, func_names: Optional[List[str]] = None) -> Set[str]: + keys = [] + for func in func_names or list(self.params.keys()): + keys += self.params.get(func, " ").keys() + return set(keys) + + def implements(self, func_name: str) -> bool: + return self.__implements__(self.base_class.__class__, func_name) + + def types(self, func_names: Optional[List[str]] = None) -> Dict[str, str]: + """Return types""" + + out: Dict[str, str] = {} + for func_name in func_names or list(self.params.keys()): + func = getattr(self.base_class, func_name) + arg_types = parse_types(func)[0][0] + for key in self.params.get(func_name, " ").keys(): + if key in out and out.get(key, " ") != arg_types.get(key, " "): + raise ValueError( + (f'Found inconsistent types for argument {key}. 
' + f'Expected type {out.get(key, " ")} but found type ' + f'{arg_types.get(key, " ")}.')) + out[key] = arg_types.get(key, " ") + return out + + def distribute(self, func_name, kwargs: Dict[str, Any]): + """Distribute""" + + out = {} + try: + for key, param in self.params.get(func_name, " ").items(): + data = kwargs.get(key, inspect.Parameter.empty) + if data is inspect.Parameter.empty: + data = param.default + out[key] = data + except KeyError: + raise TypeError(f'Required parameter {key} is empty.') + return out + + +def gather(params, indices, axis=None): + """Gather""" + if axis is None: + axis = 0 + if axis < 0: + axis = len(params.shape) + axis + if axis == 0: + return params[indices] + if axis == 1: + return params[:, indices] + if axis == 2: + return params[:, :, indices] + if axis == 3: + return params[:, :, :, indices] + raise ValueError("Unknown axis selected") + + +def broadcast(src: ms.Tensor, other: ms.Tensor, dim: int): + """Broadcast""" + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = ops.ExpandDims()(src, 0) + for _ in range(src.dim(), other.dim()): + src = ops.ExpandDims()(src, -1) + src = src.expand_as(other) + return src + + +def tensor_scatter_add(out, index, src, dim): + """Tensor scatter add""" + if dim < 0: + dim = out.ndim + dim + if out.ndim == 1: + out = ops.Cast()(out, ms.float32) + index = index.reshape(index.shape[0], 1) + src = ops.Cast()(src, ms.float32) + out = ops.scatter_nd_add(out, index, src) + elif out.ndim == 2: + if dim == 0: + m = index.shape[0] + n = index.shape[1] + index_new = index[:, :].reshape(-1)[:, None] + index_j = mnp.arange(n).astype(mnp.int32)[None,] + index_j = mnp.tile(index_j, (m, 1)).reshape(-1)[:, None] + index = mnp.concatenate((index_new, index_j), -1) # m*n, 2 + src = src[:, :].reshape(-1) # m*n, + out = ops.tensor_scatter_add(out, index, src) + return out + + +def scatter_sum(src: Tensor, index: Tensor, dim: int = -1, + out: Optional[Tensor] = None, + dim_size: Optional[int] = None) -> Tensor: + """Scatter sum""" + index = broadcast(index, src, dim) + index = index.astype(ms.int32) + if out is None: + size = list(src.shape) + if dim_size is not None: + size[dim] = dim_size + elif index.size() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = ops.Zeros()(tuple(size), src.dtype) + out = tensor_scatter_add(out, index, src, dim) + return out + out = tensor_scatter_add(out, index, src, dim) + return out + + +def scatter_mean(src: ms.Tensor, index: ms.Tensor, dim: int = -1, + out: Optional[ms.Tensor] = None, + dim_size: Optional[int] = None) -> ms.Tensor: + """Scatter mean""" + out = scatter_sum(src, index, dim, out, dim_size) + dim_size = out.shape[dim] + + index_dim = dim + + if index_dim < 0: + index_dim = 0 + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones = ops.Ones()(tuple(index.shape), ms.int32) + count = scatter_sum(ones, index, index_dim, None, dim_size) + count[count < 1] = 1 + count = broadcast(count, out, dim) + out = ms.numpy.true_divide(out, count) + return out + + +class MessagePassing(nn.Cell): + """Message passing class""" + + special_args = { + 'edge_index', 'x', 'edge_weight' + } + + def __init__(self, flow: str = "source_to_target", node_dim=-2): + super().__init__() + self.flow = flow + self.node_dim = node_dim + + self.inspector = Inspector(self) + self.inspector.inspect(self.message) + self.__user_args__ = \ + self.inspector.keys(['message',]).difference(self.special_args) + + def __check_input__(self, edge_index, size): + 
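+        # edge_index is expected as an int32 Tensor of shape [2, num_edges]:
+        # row 0 carries source node indices and row 1 destination node indices,
+        # which is the layout message_gvp() and aggregate() below rely on.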
the_size: List[Optional[int]] = [None, None] + + if isinstance(edge_index, Tensor): + assert edge_index.dtype == ms.int32 + assert edge_index.dim() == 2 + assert edge_index.shape[0] == 2 + if size is not None: + the_size[0] = size[0] + the_size[1] = size[1] + return the_size + + raise ValueError( + ('`MessagePassing.propagate` only supports `torch.LongTensor` of ' + 'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for ' + 'argument `edge_index`.')) + + def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor): + the_size = size[dim] + if the_size is None: + size[dim] = src.shape[self.node_dim] + elif the_size != src.shape[self.node_dim]: + raise ValueError( + (f'Encountered tensor with size {src.shape[self.node_dim]} in ' + f'dimension {self.node_dim}, but expected size {the_size}.')) + + def __lift__(self, src, edge_index, dim): + if isinstance(edge_index, Tensor): + index = edge_index[dim] + return src.gather(index, self.node_dim) + raise ValueError + + def __collect__(self, args, edge_index, size, kwargs): + i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1) + + out = {} + for arg in args: + if arg[-2:] not in ['_i', '_j']: + out[arg] = kwargs.get(arg, Parameter.empty) + else: + dim = j if arg[-2:] == '_j' else i + data = kwargs.get(arg[:-2], Parameter.empty) + + if isinstance(data, (tuple, list)): + assert len(data) == 2 + if isinstance(data[1 - dim], Tensor): + self.__set_size__(size, 1 - dim, data[1 - dim]) + data = data[dim] + + if isinstance(data, Tensor): + self.__set_size__(size, dim, data) + data = self.__lift__(data, edge_index, dim) + + out[arg] = data + + if isinstance(edge_index, Tensor): + out['adj_t'] = None + out['edge_index'] = edge_index + out['edge_index_i'] = edge_index[i] + out['edge_index_j'] = edge_index[j] + out['ptr'] = None + + out['index'] = out.get('edge_index_i', " ") + out['size'] = size + out['size_i'] = size[1] if size[1] is not None else size[0] + out['size_j'] = size[0] if size[0] is not None else size[1] + out['dim_size'] = out.get('size_i', " ") + return out + + def message_gvp(self, x, edge_index, edge_weight=None): + msg = gather(x, edge_index[0, :]) + if edge_weight is not None: + edge_weight = ops.ExpandDims()(edge_weight, -1) + return msg * edge_weight + return msg + + def aggregate(self, msg, edge_index, num_nodes=None, aggr='mean'): + dst_index = edge_index[1, :] + if aggr == 'mean': + return scatter_mean(msg, dst_index, dim=self.node_dim, dim_size=num_nodes) + raise NotImplementedError('Not support for this opearator') + + def update(self, x): + return x + + def propagate(self, x, edge_index, aggr='sum', size: Size = None, **kwargs): + """Propagate""" + if 'num_nodes' not in kwargs.keys() or kwargs.get('num_nodes', ' ') is None: + kwargs['num_nodes'] = x.shape[0] + size = self.__check_input__(edge_index, size) + coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs) + msg_kwargs = self.inspector.distribute('message', coll_dict) + msg = self.message(**msg_kwargs) + if aggr == 'mean': + x = self.aggregate(msg, edge_index, num_nodes=kwargs.get('num_nodes', ' '), aggr=aggr) + x = self.update(x) + return x diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/transformer_decoder.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/transformer_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f00cf2a6de9006ea32f23db6b63f9415b22a376d --- /dev/null +++ 
b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/transformer_decoder.py @@ -0,0 +1,380 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Constructing decoder in transformer network""" + +import math +from typing import Dict, List, Optional + +import mindspore as ms +import mindspore.ops as ops +import mindspore.nn as nn +from mindspore.common.initializer import Normal, initializer +from mindspore import Tensor +# pylint: disable=relative-beyond-top-level +from .basic_modules import Dense, SinusoidalPositionalEmbedding, \ +MultiheadAttention, _set_input_buffer, _get_input_buffer +from .util import ms_transpose + + +def fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return ops.Fill()(ms.float32, t.shape, float("-inf")) + + +class TransformerDecoder(nn.Cell): + """Transformer decoder""" + + def __init__( + self, + args, + dictionary, + embed_tokens, + ): + super().__init__() + self.args = args + self.dictionary = dictionary + self._future_mask = ms.numpy.empty((0)) + + self.dropout_module = nn.Dropout(1 - args.dropout) + + input_embed_dim = embed_tokens.embedding_size + embed_dim = args.decoder_embed_dim + self.embed_dim = embed_dim + self.padding_idx = embed_tokens.padding_idx + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) + + self.project_in_dim = ( + Dense(input_embed_dim, embed_dim, has_bias=False) + if embed_dim != input_embed_dim + else None + ) + self.embed_positions = SinusoidalPositionalEmbedding( + embed_dim, + self.padding_idx, + ) + + self.layers = nn.CellList([]) + self.layers.extend( + [ + self.build_decoder_layer(args) + for _ in range(args.decoder_layers) + ] + ) + self.num_layers = len(self.layers) + self.layer_norm = nn.LayerNorm([embed_dim]) + + self.build_output_projection(args, dictionary) + + def build_output_projection(self, args, dictionary): + self.output_projection = Dense( + args.decoder_embed_dim, len(dictionary), has_bias=False + ) + self.output_projection.weight = initializer(Normal(sigma=args.decoder_embed_dim ** -0.5, mean=0), + shape=self.output_projection.weight.shape, + dtype=self.output_projection.weight.dtype) + + def build_decoder_layer(self, args): + return TransformerDecoderLayer(args) + + def construct( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[ms.Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[ms.Tensor]]]] = None, + features_only: bool = False, + return_all_hiddens: bool = False, + ): + """Transformer decoder construction""" + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + ) + _ = return_all_hiddens + if not features_only: + x = self.output_layer(x) + x = ms_transpose(x, 1, 2) # B x T x C -> B x C x T + return x, extra + + def extract_features( + self, + prev_output_tokens, + encoder_out: 
Optional[Dict[str, List[ms.Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[ms.Tensor]]]] = None, + ): + """Extract features""" + + bs, _ = prev_output_tokens.shape + + enc: Optional[ms.float32] = None + padding_mask: Optional[ms.float32] = None + if encoder_out is not None and encoder_out["encoder_out"]: + enc = encoder_out["encoder_out"][0] + assert ( + enc.shape[1] == bs + ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" + if encoder_out is not None and encoder_out["encoder_padding_mask"]: + padding_mask = encoder_out["encoder_padding_mask"][0] + + # embed positions + positions = self.embed_positions( + prev_output_tokens + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + positions = positions[:, -1:] + + # embed tokens and positions + prev_output_tokens = ops.Cast()(prev_output_tokens, ms.int32) + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + x += positions + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = ms_transpose(x, 0, 1) + + self_attn_padding_mask: Optional[ms.Tensor] = None + if ops.Equal()(prev_output_tokens, self.padding_idx).any(): + self_attn_padding_mask = ops.Equal()(prev_output_tokens, self.padding_idx) + + # decoder layers + inner_states: List[Optional[ms.Tensor]] = [x] + for _, layer in enumerate(self.layers): + if incremental_state is None: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, _, _ = layer( + x, + enc, + padding_mask, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=False, + need_head_weights=False, + ) + inner_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x C x T + x = ms_transpose(x, 0, 1) + + return x, {"inner_states": inner_states} + + def output_layer(self, features): + """Project features to the vocabulary size.""" + return self.output_projection(features) + + def buffered_future_mask(self, tensor): + """Buffered future mask""" + + dim = tensor.shape[0] + if ( + self._future_mask.shape[0] == 0 + or self._future_mask.shape[0] < dim + ): + self._future_mask = fill_with_neg_inf(ops.Zeros()((dim, dim), ms.float32)) + mask = ms.nn.Triu()(ms.ops.ones(self._future_mask.shape, ms.bool_), 1) + self._future_mask[ms.numpy.logical_not(mask)] = 0 + + self._future_mask = self._future_mask.astype(tensor.dtype) + return self._future_mask[:dim, :dim] + + +class TransformerDecoderLayer(nn.Cell): + """Transformer decoder layer""" + + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.dropout_module = nn.Dropout(1 - args.dropout) + + self.self_attn = self.build_self_attention( + self.embed_dim, + args, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + self.nh = self.self_attn.num_heads + self.head_dim = self.self_attn.head_dim + + self.activation_fn = ops.ReLU() + + self.self_attn_layer_norm = nn.LayerNorm([self.embed_dim]) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) + self.encoder_attn_layer_norm = nn.LayerNorm([self.embed_dim]) + + self.ffn_layernorm = ( + nn.LayerNorm([args.decoder_ffn_embed_dim]) + if getattr(args, "scale_fc", False) + else None + ) + self.w_resid = ( + ms.Parameter( + 
ops.Ones()( + self.embed_dim, + ), + requires_grad=True, + ) + if getattr(args, "scale_resids", False) + else None + ) + + self.fc1 = self.build_fc1( + self.embed_dim, + args.decoder_ffn_embed_dim, + ) + self.fc2 = self.build_fc2( + args.decoder_ffn_embed_dim, + self.embed_dim, + ) + + self.final_layer_norm = nn.LayerNorm([self.embed_dim]) + self.need_attn = True + + def build_fc1(self, input_dim, output_dim): + return Dense(input_dim, output_dim) + + def build_fc2(self, input_dim, output_dim): + return Dense(input_dim, output_dim) + + def build_self_attention( + self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + ): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=True, + ) + + def build_encoder_attention(self, embed_dim, args): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + kdim=args.encoder_embed_dim, + vdim=args.encoder_embed_dim, + dropout=args.attention_dropout, + encoder_decoder_attention=True, + ) + + def residual_connection(self, x, residual): + return residual + x + + def construct( + self, + x, + encoder_out: Optional[ms.Tensor] = None, + encoder_padding_mask: Optional[ms.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + prev_self_attn_state: Optional[List[ms.Tensor]] = None, + prev_attn_state: Optional[List[ms.Tensor]] = None, + self_attn_mask: Optional[ms.Tensor] = None, + self_attn_padding_mask: Optional[ms.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + ): + """Transformer decoder layer construction""" + + if need_head_weights: + need_attn = True + + residual = x + x = self.self_attn_layer_norm(x) + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] + assert incremental_state is not None + _set_input_buffer(self.self_attn, incremental_state, saved_state) + _ = _get_input_buffer(self.self_attn, incremental_state) + y = x + + x, attn = self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + + if self.encoder_attn is not None and encoder_out is not None: + residual = x + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + _set_input_buffer(self.encoder_attn, incremental_state, saved_state) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + + residual = x + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + if self.ffn_layernorm is not None: + x = self.ffn_layernorm(x) + x 
= self.fc2(x)
+        x = self.dropout_module(x)
+        if self.w_resid is not None:
+            residual = ops.Mul()(self.w_resid, residual)
+        x = self.residual_connection(x, residual)
+        return x, attn, None
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/transformer_encoder.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/transformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a4aff0219ab993b25ca81d7957d1673a8653280
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/transformer_encoder.py
@@ -0,0 +1,268 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Constructing encoder in transformer network"""
+
+import argparse
+import math
+from typing import Optional
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn, Tensor
+# pylint: disable=relative-beyond-top-level
+from .features import GVPInputFeaturizer, DihedralFeatures, GVPGraphEmbedding
+from .util import nan_to_num, get_rotation_frames, rotate, rbf, unflatten_graph, ms_transpose, ms_flatten
+from .basic_modules import GVPConvLayer, MultiheadAttention, Dense, SinusoidalPositionalEmbedding
+
+
+class TransformerEncoderLayer(nn.Cell):
+    """Transformer encoder layer"""
+
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+        self.embed_dim = args.encoder_embed_dim
+        self.self_attn = self.build_self_attention(self.embed_dim, args)
+        self.self_attn_layer_norm = nn.LayerNorm([self.embed_dim])
+        self.dropout_module = nn.Dropout(1 - args.dropout)
+        self.activation_fn = ops.ReLU()
+        self.fc1 = self.build_fc1(
+            self.embed_dim,
+            args.encoder_ffn_embed_dim,
+        )
+        self.fc2 = self.build_fc2(
+            args.encoder_ffn_embed_dim,
+            self.embed_dim,
+        )
+
+        self.final_layer_norm = nn.LayerNorm([self.embed_dim])
+
+    def build_fc1(self, input_dim, output_dim):
+        return Dense(input_dim, output_dim)
+
+    def build_fc2(self, input_dim, output_dim):
+        return Dense(input_dim, output_dim)
+
+    def build_self_attention(self, embed_dim, args):
+        return MultiheadAttention(
+            embed_dim,
+            args.encoder_attention_heads,
+            dropout=args.attention_dropout,
+            self_attention=True,
+        )
+
+    def residual_connection(self, x, residual):
+        return residual + x
+
+    def construct(
+        self,
+        x,
+        encoder_padding_mask: Optional[Tensor],
+        attn_mask: Optional[Tensor] = None,
+    ):
+        """Transformer encoder layer construction"""
+
+        if attn_mask is not None:
+            attn_mask = ms.ops.MaskedFill()(attn_mask, attn_mask.astype(ms.bool_),
+                                            -1e8 if x.dtype == ms.float32 else -1e4)
+
+        residual = x
+        x = self.self_attn_layer_norm(x)
+        x, _ = self.self_attn(
+            query=x,
+            key=x,
+            value=x,
+            key_padding_mask=encoder_padding_mask,
+            need_weights=False,
+            attn_mask=attn_mask,
+        )
+        x = self.dropout_module(x)
+        x = self.residual_connection(x, residual)
+
+        residual = x
+        x = self.final_layer_norm(x)
+        x = self.activation_fn(self.fc1(x))
+        x = self.fc2(x)
+        x = self.dropout_module(x)
+        x = self.residual_connection(x, residual)
+        return x
+
+
+class GVPEncoder(nn.Cell):
+    """GVP encoder"""
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+        self.embed_graph = GVPGraphEmbedding(args)
+
+        node_hidden_dim = (args.node_hidden_dim_scalar,
+                           args.node_hidden_dim_vector)
+        edge_hidden_dim = (args.edge_hidden_dim_scalar,
+                           args.edge_hidden_dim_vector)
+
+        conv_activations = (ops.ReLU(), ops.Sigmoid())
+        self.encoder_layers = nn.CellList(
+            [GVPConvLayer(
+                node_hidden_dim,
+                edge_hidden_dim,
+                drop_rate=args.dropout,
+                vector_gate=True,
+                attention_heads=0,
+                n_message=3,
+                conv_activations=conv_activations,
+                n_edge_gvps=0,
+                eps=1e-4,
+                layernorm=True,
+            )
+             for i in range(args.num_encoder_layers)]
+        )
+
+    def construct(self, coords, coord_mask, padding_mask, confidence):
+        node_embeddings, edge_embeddings, edge_index = self.embed_graph(
+            coords, coord_mask, padding_mask, confidence)
+
+        for _, layer in enumerate(self.encoder_layers):
+            node_embeddings, edge_embeddings = layer(node_embeddings,
+                                                     edge_index, edge_embeddings)
+
+        node_embeddings = unflatten_graph(node_embeddings, coords.shape[0])
+        return node_embeddings
+
+
+class GVPTransformerEncoder(nn.Cell):
+    """GVP transformer encoder"""
+
+    def __init__(self, args, dictionary, embed_tokens):
+        super().__init__()
+        self.args = args
+        self.dictionary = dictionary
+
+        self.dropout_module = nn.Dropout(1 - args.dropout)
+
+        embed_dim = embed_tokens.embedding_size
+        self.padding_idx = embed_tokens.padding_idx
+
+        self.embed_tokens = embed_tokens
+        self.embed_scale = math.sqrt(embed_dim)
+        self.embed_positions = SinusoidalPositionalEmbedding(
+            embed_dim,
+            self.padding_idx,
+        )
+        self.embed_gvp_input_features = Dense(15, embed_dim)
+        self.embed_confidence = Dense(16, embed_dim)
+        self.embed_dihedrals = DihedralFeatures(embed_dim)
+
+        gvp_args = argparse.Namespace()
+        for k, v in vars(args).items():
+            if k.startswith("gvp_"):
+                setattr(gvp_args, k[4:], v)
+        self.gvp_encoder = GVPEncoder(gvp_args)
+        gvp_out_dim = gvp_args.node_hidden_dim_scalar + \
+                      (3 * gvp_args.node_hidden_dim_vector)
+        self.embed_gvp_output = Dense(gvp_out_dim, embed_dim)
+
+        self.layers = nn.CellList([])
+        self.layers.extend(
+            [self.build_encoder_layer(args) for i in range(args.encoder_layers)]
+        )
+        self.num_layers = len(self.layers)
+        self.layer_norm = nn.LayerNorm([embed_dim])
+
+    def build_encoder_layer(self, args):
+        return TransformerEncoderLayer(args)
+
+    def forward_embedding(self, coords, padding_mask, confidence):
+        """GVP transformer encoder embedding"""
+
+        components = dict()
+        coord_mask = ops.IsFinite()(coords).all(axis=-1).all(axis=-1)
+        coords = nan_to_num(coords)
+        mask_tokens = (
+            padding_mask * self.dictionary.padding_idx
+            + ~padding_mask * self.dictionary.get_idx("<mask>")
+        )
+        components["tokens"] = self.embed_tokens(mask_tokens) * self.embed_scale
+        components["dihedrals"] = self.embed_dihedrals(coords)
+
+        # GVP encoder
+        gvp_out_scalars, gvp_out_vectors = \
+            self.gvp_encoder(coords, coord_mask, padding_mask, confidence)
+        r = get_rotation_frames(coords)
+        # Rotate to local rotation frame for rotation-invariance
+        gvp_out_features = ops.Concat(-1)([
+            gvp_out_scalars,
+            ms_flatten(rotate(gvp_out_vectors, ms_transpose(r, r.dim()-2, r.dim()-1)), -2, -1),
+        ])
+        components["gvp_out"] = self.embed_gvp_output(gvp_out_features)
+
+        components["confidence"] = self.embed_confidence(
+            rbf(confidence, 0., 1.))
+
+        # In addition to GVP encoder
outputs, also directly embed GVP input node + # features to the Transformer + scalar_features, vector_features = GVPInputFeaturizer.get_node_features( + coords, coord_mask, with_coord_mask=False) + features = ops.Concat(-1)([ + scalar_features, + ms_flatten(rotate(vector_features, ms_transpose(r, r.dim()-2, r.dim()-1)), -2, -1), + ]) + components["gvp_input_features"] = self.embed_gvp_input_features(features) + + embed = sum(components.values()) + + x = embed + x = x + self.embed_positions(mask_tokens) + x = self.dropout_module(x) + return x, components + + def construct( + self, + coords, + encoder_padding_mask, + confidence, + return_all_hiddens: bool = False, + ): + """GVP transformer encoder construction""" + + x, encoder_embedding = \ + self.forward_embedding(coords, encoder_padding_mask, confidence) + # account for padding while computing the representation + unsqueeze = ops.ExpandDims() + x = x * (1 - unsqueeze(encoder_padding_mask, -1).astype(x.dtype)) + + # B x T x C -> T x B x C + x = ms_transpose(x, 0, 1) + + encoder_states = [] + + if return_all_hiddens: + encoder_states.append(x) + + # encoder layers + for layer in self.layers: + x = layer( + x, encoder_padding_mask=encoder_padding_mask + ) + if return_all_hiddens: + assert encoder_states is not None + encoder_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_embedding": [encoder_embedding], # dictionary + "encoder_states": encoder_states, # List[T x B x C] + } diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/util.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/util.py new file mode 100644 index 0000000000000000000000000000000000000000..555ea9e30067694e95307645421b6d12fac54d2d --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/module/util.py @@ -0,0 +1,635 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Some functions used in transformer network""" + +import itertools +from typing import Sequence, Tuple, List +import biotite.structure +from biotite.structure.io import pdbx, pdb +from biotite.structure.residues import get_residues +from biotite.structure import filter_backbone +from biotite.structure import get_chains +from biotite.sequence import ProteinSequence +import numpy as np +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +proteinseq_toks = { + 'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', + 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-'] +} + + +def ms_transpose(x, index_a, index_b): + """Transpose""" + index = list(i for i in range(len(x.shape))) + index[index_a] = index_b + index[index_b] = index_a + input_trans = x.transpose(index) + return input_trans + + +def ms_sum(x, dim, keep_dims=False): + """Sum""" + op = ms.ops.ReduceSum(keep_dims=keep_dims) + return op(x, dim) + + +def ms_padding_without_val(x, padding): + """Padding""" + paddings = () + num = int(len(padding) / 2) + zero_pad = len(x.shape) - num + i = int(0) + while i < zero_pad: + i += 1 + paddings = paddings + ((0, 0),) + for j in range(num): + paddings = paddings + ((padding[(-2) * j - 2], padding[(-2) * j - 1]),) + y = ms.nn.Pad(paddings=paddings)(x) + return y + + +def ms_padding(x, num, val): + """Padding""" + num_shape = len(x.shape) + if num == -3: + a = np.ones(shape=x.shape[num + 1:]).astype(np.float32) + a[:] = val + x_pad = ms.Tensor(a) + elif num == -1: + x_pad = ms.Tensor(val) + else: + print("wrong with num, it should be -1 or -3") + pad_tuple = list((0, 0) for i in range(num_shape)) + pad_tuple[num] = (1, 1) + pad_op = ms.nn.Pad(paddings=tuple(pad_tuple)) + output = pad_op(x) + if num == -3: + output[..., 0, :, :] = x_pad + output[..., -1, :, :] = x_pad + elif num == -1: + output[..., 0] = x_pad + output[..., -1] = x_pad + else: + output[..., -1] = x_pad + return output + + +def load_structure(fpath, chain=None): + """Load structure""" + if fpath.endswith('cif'): + with open(fpath) as fin: + pdbxf = pdbx.PDBxFile.read(fin) + structure = pdbx.get_structure(pdbxf, model=1) + elif fpath.endswith('pdb'): + with open(fpath) as fin: + pdbf = pdb.PDBFile.read(fin) + structure = pdb.get_structure(pdbf, model=1) + bbmask = filter_backbone(structure) + structure = structure[bbmask] + chains = get_chains(structure) + print(f'Found {len(chains)} chains:', chains, '\n') + if not list(chains): + raise ValueError('No chains found in the input file.') + if chain is None: + chain = chains[0] + if chain not in chains: + raise ValueError(f'Chain {chain} not found in input file') + structure = structure[structure.chain_id == chain] + print(f'Loaded chain {chain}\n') + return structure + + +def extract_coords_from_structure(structure: biotite.structure.AtomArray): + """Extract coordinates from structure""" + coords = get_atom_coords_residuewise(["N", "CA", "C"], structure) + residue_identities = get_residues(structure)[1] + seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities]) + return coords, seq + + +def load_coords(fpath, chain): + """Load coordinates""" + structure = load_structure(fpath, chain) + return extract_coords_from_structure(structure) + + +def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray): + """ + Example for atoms argument: ["N", "CA", "C"] + """ 
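+    # Per residue this selects the requested atoms and yields an
+    # (n_residues, 3, 3) coordinate array; residues missing one of the atoms
+    # are filled with NaN inside filterfn (a sketch of the intent, based on
+    # how biotite.structure.apply_residue_wise segments the data).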
+
+    def filterfn(s, axis=None):
+        _ = axis
+        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
+        filter_sum = filters.sum(0)
+        if not np.all(filter_sum <= np.ones(filters.shape[1])):
+            raise RuntimeError("structure has multiple atoms with same name")
+        index = filters.argmax(0)
+        coords = s[index].coord
+        coords[filter_sum == 0] = float("nan")
+        return coords
+
+    return biotite.structure.apply_residue_wise(struct, struct, filterfn)
+
+
+def score_sequence(model, alphabet, coords, seq):
+    """Score sequences for given structure"""
+    batch_converter = CoordBatchConverter(alphabet)
+    batch = [(coords, None, seq)]
+    coords, confidence, _, tokens, padding_mask = batch_converter(batch)
+    prev_output_tokens = tokens[:, :-1]
+    target = tokens[:, 1:]
+    target_padding_mask = (target == alphabet.padding_idx)
+    coords = Tensor(coords)
+    padding_mask = Tensor(padding_mask)
+    confidence = Tensor(confidence)
+    prev_output_tokens = Tensor(prev_output_tokens)
+    model_input = (coords, padding_mask, confidence, prev_output_tokens)
+    logits = model.construct(model_input)
+    target = ms.ops.Cast()(target, ms.int32)
+    loss = nn.CrossEntropyLoss(reduction='none')(logits, target)
+    avgloss = ms_sum(loss * ~target_padding_mask, dim=-1) / ms_sum(ops.Cast()(~target_padding_mask, ms.float32), dim=-1)
+    ll_fullseq = -avgloss.asnumpy().item()
+
+    coord_bool = ms.ops.isfinite(coords)
+    coord_mask = coord_bool.all(axis=-1).all(axis=-1)
+    coord_mask = coord_mask[:, 1:-1]
+    avgloss = ms_sum(loss * coord_mask, dim=-1) / ms_sum(ops.Cast()(coord_mask, ms.float32), dim=-1)
+    ll_withcoord = -avgloss.asnumpy().item()
+
+    return ll_fullseq, ll_withcoord
+
+
+def get_encoder_output(model, alphabet, coords):
+    """Get encoder output"""
+    batch_converter = CoordBatchConverter(alphabet)
+    # the batch_converter is essential for forming the correct input format
+    batch = [(coords, None, None)]
+    coords, confidence, _, _, padding_mask = batch_converter(batch)
+    coords = Tensor(coords)
+    confidence = Tensor(confidence)
+    padding_mask = Tensor(padding_mask)
+    encoder_out = \
+        model.encoder.construct(coords, padding_mask, confidence, return_all_hiddens=False)
+    # remove beginning and end (bos and eos tokens)
+    return encoder_out['encoder_out'][0][1:-1, 0]
+
+
+def rotate(v, r):
+    """Rotate"""
+    unsqueeze = ms.ops.ExpandDims()
+    r = unsqueeze(r, -3)
+    v = unsqueeze(v, -1)
+    return ms_sum(v * r, dim=-2)
+
+
+def get_rotation_frames(coords):
+    """Get rotation frames"""
+    v1 = coords[:, :, 2] - coords[:, :, 1]
+    v2 = coords[:, :, 0] - coords[:, :, 1]
+    e1 = normalize(v1, dim=-1)
+    u2 = v2 - e1 * ms_sum(e1 * v2, dim=-1, keep_dims=True)
+    e2 = normalize(u2, dim=-1)
+    e3 = ms.numpy.cross(e1, e2)
+    stack = ms.ops.Stack(axis=-2)
+    r = stack([e1, e2, e3])
+    return r
+
+
+def utils_softmax(x, dim: int, onnx_trace: bool = False):
+    """Utils softmax"""
+    if onnx_trace:
+        return ops.Softmax(axis=dim)(ops.Cast()(x, ms.float32))
+    x = x.astype(ms.float32)
+    return ops.Softmax(axis=dim)(x)
+
+
+def tuple_size(tp):
+    """Return tuple size"""
+    return tuple([0 if a is None else a.size for a in tp])
+
+
+def tuple_sum(tp1, tp2):
+    """Return the sum of tuple"""
+    s1, v1 = tp1
+    s2, v2 = tp2
+    if v1 is None and v2 is None:
+        return (s1 + s2, None)
+    return (s1 + s2, v1 + v2)
+
+
+def tuple_cat(*args, dim=-1):
+    """Return the concat of tuple"""
+    dim %= len(args[0][0].shape)
+    s_args, v_args = list(zip(*args))
+    concat_op = ops.Concat(axis=dim)
+    return concat_op(s_args), concat_op(v_args)
+
+
+def tuple_index(x, idx):
+    """Return the index of
tuple""" + return x[0][idx], x[1][idx] + + +def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True): + square = ops.Square() + ops_sum = ops.ReduceSum(keep_dims=keepdims) + sqrt_1 = ops.Sqrt() + out = ops_sum(square(x), axis) + eps + return sqrt_1(out) if sqrt else out + + +def _split(x, nv): + """Split""" + reshape = ops.Reshape() + v = reshape(x[..., -3 * nv:], x.shape[:-1] + (nv, 3)) + s = x[..., :-3 * nv] + return s, v + + +def _merge(s, v): + """Merge""" + reshape = ops.Reshape() + v = reshape(v, v.shape[:-2] + (3 * v.shape[-2],)) + concat_op = ops.Concat(axis=-1) + a = concat_op((s, v)) + return a + + +def nan_to_num(ts, val=0.0): + val = ms.Tensor(val, dtype=ts.dtype) + return ms.numpy.where(~ms.ops.IsFinite()(ts), val, ts) + + +def rbf(values, v_min, v_max, n_bins=16): + """Radial basis function""" + linspace = ms.ops.LinSpace() + v_min = ms.Tensor(v_min, ms.float32) + v_max = ms.Tensor(v_max, ms.float32) + rbf_centers = linspace(v_min, v_max, n_bins) + rbf_centers = rbf_centers.view(tuple([1] * len(values.shape) + [-1])) + rbf_std = (v_max - v_min) / n_bins + expand_dims = ms.ops.ExpandDims() + v_expand = expand_dims(values, -1) + z = (v_expand - rbf_centers) / rbf_std + exp = ms.ops.Exp() + return exp(-z ** 2) + + +def norm(tensor, dim, eps=1e-8, keepdim=False): + sqrt = ms.ops.Sqrt() + square = ms.ops.Square() + return sqrt( + (ops.ReduceSum(keep_dims=keepdim)(square(tensor), axis=dim) + eps)) + + +def normalize(tensor, dim=-1): + """Normalization""" + div = ms.ops.Div() + y = norm(tensor, dim=dim, keepdim=True) + return nan_to_num( + div(tensor, y) + ) + + +def ms_flatten(input_tensor, start_dim, end_dim): + """Flatten""" + if start_dim == 0: + shape_list = list(input_tensor.shape[end_dim + 1:]) + dim = 1 + for i in range(start_dim, end_dim + 1): + dim = input_tensor.shape[i] * dim + shape_list.insert(0, dim) + shape_list = tuple(shape_list) + flatten = ms.ops.Reshape() + output = flatten(input_tensor, shape_list) + return output + if end_dim in (-1, input_tensor.dim() - 1): + shape_list = list(input_tensor.shape[:start_dim]) + dim = 1 + for i in range(start_dim, end_dim + 1): + dim = input_tensor.shape[i] * dim + shape_list.append(dim) + shape_list = tuple(shape_list) + flatten = ms.ops.Reshape() + output = flatten(input_tensor, shape_list) + return output + raise ValueError("Unknown dim selected") + + +def flatten_graph(node_embeddings, edge_embeddings, edge_index): + """Flatten graph""" + x_s, x_v = node_embeddings + e_s, e_v = edge_embeddings + batch_size, n = x_s.shape[0], x_s.shape[1] + node_embeddings = (x_s.reshape(((x_s.shape[0] * x_s.shape[1]), x_s.shape[2])), + x_v.reshape(((x_v.shape[0] * x_v.shape[1]), x_v.shape[2], x_v.shape[3]))) + edge_embeddings = (e_s.reshape(((e_s.shape[0] * e_s.shape[1]), e_s.shape[2])), + e_v.reshape(((e_v.shape[0] * e_v.shape[1]), e_v.shape[2], e_v.shape[3]))) + new_edge_index = ops.Cast()(edge_index != -1, ms.bool_) + edge_mask = new_edge_index.any(axis=1) + + # Re-number the nodes by adding batch_idx * N to each batch + unsqueeze = ops.ExpandDims() + edge_index = edge_index + unsqueeze(unsqueeze((ms.numpy.arange(batch_size) * n), -1), -1) + + permute = ops.Transpose() + + edge_index = permute(edge_index, (1, 0, 2)) + edge_index = edge_index.reshape(edge_index.shape[0], (edge_index.shape[1] * edge_index.shape[2])) + + edge_mask = edge_mask.flatten() + edge_mask = edge_mask.asnumpy() + edge_index = edge_index.asnumpy() + edge_embeddings_0 = edge_embeddings[0].asnumpy() + edge_embeddings_1 = edge_embeddings[1].asnumpy() + + 
edge_index = edge_index[:, edge_mask]
+    edge_embeddings = (
+        ms.Tensor(edge_embeddings_0[edge_mask, :], ms.float32),
+        ms.Tensor(edge_embeddings_1[edge_mask, :], ms.float32)
+    )
+
+    edge_index = ms.Tensor(edge_index, ms.int32)
+    return node_embeddings, edge_embeddings, edge_index
+
+
+def unflatten_graph(node_embeddings, batch_size):
+    """Unflatten graph"""
+    x_s, x_v = node_embeddings
+    x_s = x_s.reshape((batch_size, -1, x_s.shape[1]))
+    x_v = x_v.reshape((batch_size, -1, x_v.shape[1], x_v.shape[2]))
+    return (x_s, x_v)
+
+
+class Alphabet:
+    """Create alphabet"""
+    def __init__(
+            self,
+            standard_toks: Sequence[str],
+            prepend_toks: Sequence[str] = ("<null_0>", "<pad>", "<eos>", "<unk>"),
+            append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
+            prepend_bos: bool = True,
+            append_eos: bool = False,
+            use_msa: bool = False,
+    ):
+        self.standard_toks = list(standard_toks)
+        self.prepend_toks = list(prepend_toks)
+        self.append_toks = list(append_toks)
+        self.prepend_bos = prepend_bos
+        self.append_eos = append_eos
+        self.use_msa = use_msa
+
+        self.all_toks = list(self.prepend_toks)
+        self.all_toks.extend(self.standard_toks)
+        for i in range((8 - (len(self.all_toks) % 8)) % 8):
+            self.all_toks.append(f"<null_{i + 1}>")
+        self.all_toks.extend(self.append_toks)
+
+        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
+
+        self.unk_idx = self.tok_to_idx["<unk>"]
+        self.padding_idx = self.get_idx("<pad>")
+        self.cls_idx = self.get_idx("<cls>")
+        self.mask_idx = self.get_idx("<mask>")
+        self.eos_idx = self.get_idx("<eos>")
+        self.all_special_tokens = ['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']
+        self.unique_no_split_tokens = self.all_toks
+
+    def __len__(self):
+        return len(self.all_toks)
+
+    @classmethod
+    def from_architecture(cls, name: str) -> "Alphabet":
+        """Return alphabet"""
+
+        if "invariant_gvp" in name.lower():
+            standard_toks = proteinseq_toks.get("toks")
+            prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
+            append_toks = ("<mask>", "<cath>", "<af2>")
+            prepend_bos = True
+            append_eos = False
+            use_msa = False
+        else:
+            raise ValueError("Unknown architecture selected")
+        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)
+
+    @staticmethod
+    def _tokenize(text) -> List[str]:
+        return text.split()
+
+    def get_idx(self, tok):
+        return self.tok_to_idx.get(tok, self.unk_idx)
+
+    def get_tok(self, ind):
+        return self.all_toks[ind]
+
+    def get_batch_converter(self):
+        return BatchConverter(self)
+
+    def tokenize(self, text) -> List[str]:
+        """Tokenization"""
+
+        def split_on_token(tok, text):
+            result = []
+            split_text = text.split(tok)
+            for i, sub_text in enumerate(split_text):
+                if i < len(split_text) - 1:
+                    sub_text = sub_text.rstrip()
+                if i > 0:
+                    sub_text = sub_text.lstrip()
+
+                if i == 0 and not sub_text:
+                    result.append(tok)
+                elif i == len(split_text) - 1:
+                    if sub_text:
+                        result.append(sub_text)
+                else:
+                    if sub_text:
+                        result.append(sub_text)
+                    result.append(tok)
+            return result
+
+        def split_on_tokens(tok_list, text):
+            if not text.strip():
+                return []
+
+            tokenized_text = []
+            text_list = [text]
+            for tok in tok_list:
+                tokenized_text = []
+                for sub_text in text_list:
+                    if sub_text not in self.unique_no_split_tokens:
+                        tokenized_text.extend(split_on_token(tok, sub_text))
+                    else:
+                        tokenized_text.append(sub_text)
+                text_list = tokenized_text
+
+            return list(
+                itertools.chain.from_iterable(
+                    (
+                        self._tokenize(token)
+                        if token not in self.unique_no_split_tokens
+                        else [token]
+                        for token in tokenized_text
+                    )
+                )
+            )
+
+        no_split_token = self.unique_no_split_tokens
+        tokenized_text = split_on_tokens(no_split_token, text)
+        return tokenized_text
+
+    def to_dict(self):
+        return self.tok_to_idx.copy()
+
+    def encode(self, text):
+        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
+
+
+def np_padding(x, num, val):
+    """Pad axis `num` (-1 or -3) of `x` with one entry of `val` on each side"""
+    num_shape = len(x.shape)
+    if num == -3:
+        a = np.ones(shape=x.shape[num + 1:]).astype(np.float32)
+        a[:] = val
+        x_pad = a
+    elif num == -1:
+        x_pad = val
+    else:
+        raise ValueError("wrong with num, it should be -1 or -3")
+    pad_tuple = list((0, 0) for i in range(num_shape))
+    pad_tuple[num] = (1, 1)
+    output = np.pad(x, pad_tuple)
+    if num == -3:
+        output[..., 0, :, :] = x_pad
+        output[..., -1, :, :] = x_pad
+    else:
+        output[..., 0] = x_pad
+        output[..., -1] = x_pad
+    return output
+
+
+class BatchConverter:
+    """Batch conversion"""
+
+    def __init__(self, alphabet):
+        self.alphabet = alphabet
+
+    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
+        # RoBERTa uses an eos token, while ESM-1 does not.
+        batch_size = len(raw_batch)
+        batch_labels, seq_str_list = zip(*raw_batch)
+        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
+        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
+        tokens = np.ones((batch_size, max_len + int(self.alphabet.prepend_bos)
+                          + int(self.alphabet.append_eos))).astype(np.float32) * self.alphabet.padding_idx
+
+        labels = []
+        strs = []
+
+        for i, (label, seq_str, seq_encoded) in enumerate(
+                zip(batch_labels, seq_str_list, seq_encoded_list)
+        ):
+            labels.append(label)
+            strs.append(seq_str)
+            if self.alphabet.prepend_bos:
+                tokens[i, 0] = self.alphabet.cls_idx
+            seq = np.array(seq_encoded).astype(np.float32)
+            tokens[
+                i,
+                int(self.alphabet.prepend_bos) : len(seq_encoded)
+                + int(self.alphabet.prepend_bos),
+            ] = seq
+            if self.alphabet.append_eos:
+                tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
+
+        return labels, strs, tokens
+
+
+class CoordBatchConverter(BatchConverter):
+    """Batch conversion of coordinates"""
+
+    def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None):
+        self.alphabet.cls_idx = self.alphabet.get_idx("<cath>")  # "<cath>" serves as the bos token here
+        batch = []
+        for coords, confidence, seq in raw_batch:
+            if confidence is None:
+                confidence = 1.
+            if isinstance(confidence, (float, int)):
+                confidence = [float(confidence)] * len(coords)
+            if seq is None:
+                seq = 'X' * len(coords)
+            batch.append(((coords, confidence), seq))
+
+        coords_and_confidence, strs, tokens = super().__call__(batch)
+
+        # pad beginning and end of each protein due to legacy reasons
+        coords = [
+            np_padding(np.array(cd), num=-3, val=np.inf)
+            for cd, _ in coords_and_confidence
+        ]
+        confidence = [
+            np_padding(np.array(cf), num=-1, val=-1.)
+            for _, cf in coords_and_confidence
+        ]
+        coords = self.collate_dense_tensors(coords, pad_v=np.nan)
+
+        confidence = self.collate_dense_tensors(confidence, pad_v=-1.)
+        padding_mask = np.isnan(coords[:, :, 0, 0])
+        coord_mask = np.isfinite(coords.sum(-2).sum(-1))
+        confidence = confidence * coord_mask + (-1.) * padding_mask
+        output = [coords, confidence, strs, tokens, padding_mask]
+
+        return output
+
+    @staticmethod
+    def collate_dense_tensors(samples, pad_v):
+        """Collate dense tensors"""
+
+        if not samples:
+            return None
+        if len(set(x.ndim for x in samples)) != 1:
+            raise RuntimeError(
+                f"Samples have varying dimensions: {[x.ndim for x in samples]}"
+            )
+        max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
+
+        result = np.zeros((len(samples), *max_shape), np.float32)
+
+        for i, x_sample in enumerate(samples):
+            len_sample = x_sample.shape[0]
+            result[i][len_sample:] = pad_v
+            result[i][:len_sample] = x_sample
+
+        return result
+
+    def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None):
+        batch_size = len(coords_list)
+        if confidence_list is None:
+            confidence_list = [None] * batch_size
+        if seq_list is None:
+            seq_list = [None] * batch_size
+        raw_batch = zip(coords_list, confidence_list, seq_list)
+        return self.__call__(raw_batch, device)
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/nn_arch.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/nn_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..be34cb4b7be901b90853bbe283f8ca41469fbccc
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/esm/nn_arch.py
@@ -0,0 +1,117 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""gvp transformer model"""
+
+import mindspore as ms
+import mindspore.ops as ops
+import mindspore.nn as nn
+from mindspore.common.initializer import Normal, initializer, Constant
+from .module.transformer_encoder import GVPTransformerEncoder
+from .module.transformer_decoder import TransformerDecoder
+from .module.util import CoordBatchConverter
+
+
+class GVPTransformerModel(nn.Cell):
+    """GVP transformer model"""
+
+    def __init__(self, args, alphabet):
+        super(GVPTransformerModel, self).__init__()
+        encoder_embed_tokens = self.build_embedding(
+            alphabet, args.encoder_embed_dim,
+        )
+        decoder_embed_tokens = self.build_embedding(
+            alphabet, args.decoder_embed_dim,
+        )
+        encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
+        decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
+        self.args = args
+        self.encoder = encoder
+        self.decoder = decoder
+
+    @classmethod
+    def build_encoder(cls, args, src_dict, embed_tokens):
+        encoder = GVPTransformerEncoder(args, src_dict, embed_tokens)
+        return encoder
+
+    @classmethod
+    def build_decoder(cls, args, tgt_dict, embed_tokens):
+        decoder = TransformerDecoder(
+            args,
+            tgt_dict,
+            embed_tokens,
+        )
+        return decoder
+
+    @classmethod
+    def build_embedding(cls, dictionary, embed_dim):
+        num_embeddings = len(dictionary)
+        padding_idx = dictionary.padding_idx
+        emb = nn.Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
+        emb.embedding_table = initializer(Normal(mean=0, sigma=embed_dim ** -0.5), emb.embedding_table.shape,
+                                          dtype=ms.float32)
+        Constant(0)(emb.embedding_table[padding_idx])  # zero the padding token's embedding row
+        return emb
+
+    def construct(self, net_input):
+        """Transformer construction"""
+
+        coords, padding_mask, confidence, prev_output_tokens = net_input
+        return_all_hiddens: bool = False
+        features_only: bool = False
+        encoder_out = self.encoder(coords, padding_mask, confidence, return_all_hiddens=return_all_hiddens)
+        logits, _ = self.decoder(
+            prev_output_tokens,
+            encoder_out=encoder_out,
+            features_only=features_only,
+            return_all_hiddens=return_all_hiddens,
+        )
+        return logits
+
+    def sample(self, coords, temperature=1.0, confidence=None):
+        """Sample sequence designs for a given structure"""
+
+        l_coords = len(coords)
+        # Convert to batch format
+        batch_converter = CoordBatchConverter(self.decoder.dictionary)
+        batch_coords, confidence, _, _, padding_mask = (
+            batch_converter([(coords, confidence, None)])
+        )
+
+        # Start with prepend token
+        sampled_tokens = ops.Zeros()((1, 1 + l_coords), ms.float32)
+        sampled_tokens[0, 0] = self.decoder.dictionary.get_idx('<cath>')
+
+        # Save incremental states for faster sampling
+        incremental_state = dict()
+
+        # Run encoder only once
+        encoder_out = self.encoder(batch_coords, padding_mask, confidence)
+
+        # Decode one token at a time
+        for i in range(1, l_coords + 1):
+            logits, _ = self.decoder(sampled_tokens[:, :i], encoder_out, incremental_state=incremental_state)
+            logits = logits[0].reshape(1, -1)
+            logits /= temperature
+            softmax = ops.Softmax(axis=-1)
+            probs = softmax(logits)
+            probs = probs.reshape(1, -1)
+            tokens = ops.Argmax()(probs)  # greedy decoding: temperature rescales the logits, argmax stays deterministic
+            sampled_tokens[:, i] = tokens
+        sampled_seq = sampled_tokens[0, 1:]
+        sampled_seq = ops.Cast()(sampled_seq, ms.int32)
+
+        # Convert back to string via lookup
+        output = ''.join([self.decoder.dictionary.get_tok(a) for a in sampled_seq])
+        return output
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/__init__.py
b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94824050f12e5eeaadf30fd8c81df7743d8b4205 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" + +from .megaassessment import MEGAAssessment +from .megaassessment_dataset import MEGAAssessmentDataSet +from .megaassessment_configuration import megaassessment_configuration diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/megaassessment_configuration.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/megaassessment_configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..47e684b4ddf8bb8b9818a80c41176166e99d9b9e --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/megaassessment_configuration.py @@ -0,0 +1,28 @@ +# Copyright 2023 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""megaassessment_configuration""" + +megaassessment_configuration = { + "training": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/initial_train.yaml", + "predict_256": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/predict_256.yaml", + "predict_512": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/predict_512.yaml", + "predict_768": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/predict_768.yaml", + "predict_1024": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/predict_1024.yaml", + "predict_1280": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/predict_1280.yaml", + "predict_1536": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/predict_1536.yaml", + "predict_1792": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/predict_1792.yaml", + "predict_2048": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/predict_2048.yaml", + "predict_2304": "https://download.mindspore.cn/mindscience/mindsponge/MEGAFold/config/predict_2304.yaml", +} diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/megassessment.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/megassessment.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e40e2b62de2c8d16849250698af33a31ad68bb --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/megassessment.py @@ -0,0 +1,196 @@ +# Copyright 2023 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""megaassessment""" + +import time +import os +import ssl +import urllib +import numpy as np +from mindspore import jit, context, nn, load_param_into_net, Tensor +from mindspore.common import mutable +from .module.assessment_wrapcell import TrainOneStepCell, WithLossCell +from .nn_arch import CombineModel as megaassessment +from .nn_arch import load_weights +from ..model import Model + + +class MEGAAssessment(Model): + """megaassessment model""" + name = "MEGAssessment" + feature_list = ['target_feat', 'msa_feat', 'msa_mask', 'seq_mask', 'aatype', + 'template_aatype', 'template_all_atom_masks', 'template_all_atom_positions', + 'template_mask', 'template_pseudo_beta_mask', 'template_pseudo_beta', 'extra_msa', + 'extra_has_deletion', 'extra_deletion_value', 'extra_msa_mask', + 'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index', + 'prev_pos', 'prev_msa_first_row', 'prev_pair', 'decoy_atom_positions', 'decoy_atom_mask'] + + label_list = ["pseudo_beta", "pseudo_beta_mask", "all_atom_mask", "true_msa", "bert_mask", + "residx_atom14_to_atom37", "restype_atom14_bond_lower_bound", "restype_atom14_bond_upper_bound", + "atomtype_radius", "backbone_affine_tensor", "backbone_affine_mask", "atom14_gt_positions", + "atom14_alt_gt_positions", "atom14_atom_is_ambiguous", "atom14_gt_exists", "atom14_atom_exists", + "atom14_alt_gt_exists", "all_atom_positions", "rigidgroups_gt_frames", "rigidgroups_gt_exists", + "rigidgroups_alt_gt_frames", "torsion_angles_sin_cos", "chi_mask"] + + def __init__(self, config): + context.set_context(memory_optimize_level="O1", max_call_depth=6000) + if context.get_context("device_target") == "GPU": + self.mixed_precision = False + context.set_context(graph_kernel_flags="--disable_expand_ops=Softmax --disable_cluster_ops=ReduceSum " + "--composite_op_limit_size=50", enable_graph_kernel=True) + else: + self.mixed_precision = True + + self.config = config + self.use_jit = self.config.use_jit + self.use_jit = True + self.network = megaassessment(self.config, self.mixed_precision) + if self.config.is_training: + self.checkpoint_url = 'https://download.mindspore.cn/mindscience/mindsponge/' \ + 'MEGAFold/checkpoint/MEGA_Fold_1.ckpt' + self.checkpoint_path = "./MEGA_Fold_1.ckpt" + if not os.path.exists(self.checkpoint_path): + print("Download checkpoint to ", self.checkpoint_path) + # pylint: disable=protected-access + ssl._create_default_https_context = ssl._create_unverified_context + urllib.request.urlretrieve(self.checkpoint_url, self.checkpoint_path) + param_dict = load_weights(self.checkpoint_path, self.config.model) + load_param_into_net(self.network, param_dict) + else: + self.checkpoint_url = 'https://download.mindspore.cn/mindscience/mindsponge/' \ + 'MEGAAssessment/checkpoint/MEGA_Assessment.ckpt' + self.checkpoint_path = "./MEGA_Assessment.ckpt" + net_with_criterion = WithLossCell(self.network, self.config) + lr = 0.0001 + opt = nn.Adam(params=self.network.trainable_params(), learning_rate=lr, eps=1e-6) + self.train_net = TrainOneStepCell(net_with_criterion, opt, sens=1, gradient_clip_value=0.1) + super().__init__(self.checkpoint_url, self.network, self.name) + + # pylint: disable=arguments-differ + def forward(self, data, run_pretrain=True): + """forward""" + if self.use_jit: + outputs = self._jit_forward(data, run_pretrain=run_pretrain) + else: + outputs = self._pynative_forward(data, run_pretrain=run_pretrain) + return outputs + + # pylint: disable=arguments-differ + def 
predict(self, inputs):
+        """predict"""
+        recycle_feature_name = self.feature_list[:-5]
+        prev_pos = Tensor(inputs['prev_pos'])
+        prev_msa_first_row = Tensor(inputs['prev_msa_first_row'])
+        prev_pair = Tensor(inputs['prev_pair'])
+        data = {}
+        for recycle in range(4):
+            for key in recycle_feature_name:
+                data[key] = Tensor(inputs[key][recycle])
+            data['prev_pos'] = prev_pos
+            data['prev_msa_first_row'] = prev_msa_first_row
+            data['prev_pair'] = prev_pair
+            data = mutable(data)
+            t1 = time.time()
+            prev_pos, prev_msa_first_row, prev_pair, _ = self.forward(data, run_pretrain=True)
+            t2 = time.time()
+            print("forward time : ", round(t2 - t1, 2))
+        data['prev_pos'] = prev_pos
+        data['prev_msa_first_row'] = prev_msa_first_row
+        data['prev_pair'] = prev_pair
+        data['decoy_atom_positions'] = Tensor(inputs['decoy_atom_positions'])
+        data['decoy_atom_mask'] = Tensor(inputs['decoy_atom_mask'])
+
+        plddt = self.forward(data, run_pretrain=False)
+        plddt = plddt.asnumpy()[inputs['align_mask'] == 1]
+        return plddt
+
+    # pylint: disable=arguments-differ
+    @jit
+    def backward(self, feat):
+        """backward"""
+        loss = self.train_net(*feat)
+        return loss
+
+    # pylint: disable=arguments-differ
+    def train_step(self, data):
+        """train one step"""
+        num_recycle = np.random.randint(low=1, high=5)
+        self.train_net.add_flags_recursive(train_backward=False)
+        self.train_net.phase = 'train_forward'
+        recycle_feature_name = self.feature_list[:-5]
+        prev_pos = Tensor(data['prev_pos'])
+        prev_msa_first_row = Tensor(data['prev_msa_first_row'])
+        prev_pair = Tensor(data['prev_pair'])
+        for recycle in range(4):
+            inputs = {}
+            for key in recycle_feature_name:
+                inputs[key] = Tensor(data[key][recycle])
+            inputs['prev_pos'] = prev_pos
+            inputs['prev_msa_first_row'] = prev_msa_first_row
+            inputs['prev_pair'] = prev_pair
+            inputs = mutable(inputs)
+            t1 = time.time()
+            prev_pos, prev_msa_first_row, prev_pair, _ = self.forward(inputs, run_pretrain=True)
+            if recycle == num_recycle - 1:  # recycle is 0-based; capture the output of the num_recycle-th forward
+                final_atom_positions_recycle = prev_pos
+            t2 = time.time()
+            print("forward time : ", round(t2 - t1, 2))
+        inputs = {}
+        for key in self.feature_list[:-5]:
+            inputs[key] = Tensor(data[key][num_recycle - 1])
+        inputs['prev_pos'] = prev_pos
+        inputs['prev_msa_first_row'] = prev_msa_first_row
+        inputs['prev_pair'] = prev_pair
+        for key in self.label_list:
+            inputs[key] = Tensor(data[key])
+        self.train_net.add_flags_recursive(train_backward=True)
+        self.train_net.phase = 'train_backward'
+        keys = self.feature_list[:-2] + self.label_list
+        feat = []
+        for key in keys:
+            feat.append(inputs.get(key))
+        feat.append(final_atom_positions_recycle)
+        feat.append(inputs.get('atom37_atom_exists'))
+        feat = mutable(feat)
+        t1 = time.time()
+        loss = self.backward(feat)
+        t2 = time.time()
+        print("backward time : ", round(t2 - t1, 2))
+        return loss
+
+    # pylint: disable=arguments-differ
+    @jit
+    def _jit_forward(self, data, run_pretrain=True):
+        """forward with jit mode"""
+        feat = []
+        feature_list = self.feature_list
+        if run_pretrain:
+            feature_list = self.feature_list[:-2]
+        for key in feature_list:
+            feat.append(data[key])
+        outputs = self.network(*feat, run_pretrain=run_pretrain)
+        return outputs
+
+    # pylint: disable=arguments-differ
+    def _pynative_forward(self, data, run_pretrain=True):
+        """forward with pynative mode"""
+        feat = []
+        feature_list = self.feature_list
+        if run_pretrain:
+            feature_list = self.feature_list[:-2]
+        for key in feature_list:
+            feat.append(data[key])
+        outputs = self.network(*feat, run_pretrain=run_pretrain)
+        return outputs
diff --git
a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/megassessment_dataset.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/megassessment_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c011ddc0ecc05b8cea07560f140581fbacb480dc --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/megassessment_dataset.py @@ -0,0 +1,85 @@ +# Copyright 2023 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""megaassessment dataset""" +import numpy as np +from mindsponge.common.residue_constants import order_restype_with_x +from mindsponge.common.utils import get_aligned_seq + +from ..megafold.megafold_dataset import MEGAFoldDataSet + + +class MEGAAssessmentDataSet(MEGAFoldDataSet): + """megasssessment dataset""" + def __init__(self, config, seed=0): + self.config = config + self.supported_models = ['MEGAAssessment'] + super().__init__(self.config, seed) + + def process(self, data, label=None, ensemble_num=4): + features = super().process(data, label, ensemble_num) + if not label: + features = self.process_pdb(features, data) + return features + + def align_with_aatype(self, true_aatype, aatype, atom37_positions, atom37_mask): + """align pdb with aatype""" + if len(true_aatype) == len(aatype): + out = aatype, atom37_positions, atom37_mask, np.ones((aatype.shape[0])).astype(np.float32) + return out + seq1 = [order_restype_with_x.get(x) for x in aatype] + seq2 = [order_restype_with_x.get(x) for x in true_aatype] + seq1 = ''.join(seq1) + seq2 = ''.join(seq2) + _, align_relationship, _ = get_aligned_seq(seq1, seq2) + pdb_index = 0 + seq_len = len(true_aatype) + new_aatype = np.zeros((seq_len,)).astype(np.int32) + new_atom37_positions = np.zeros((seq_len, 37, 3)).astype(np.float32) + new_atom37_mask = np.zeros((seq_len, 37)).astype(np.float32) + align_mask = np.zeros((seq_len,)).astype(np.float32) + for i in range(len(true_aatype)): + if align_relationship[i] == "-": + new_aatype[i] = 20 + new_atom37_positions[i] = np.zeros((37, 3)).astype(np.float32) + new_atom37_mask[i] = np.zeros((37,)).astype(np.float32) + align_mask[i] = 0 + else: + new_aatype[i] = aatype[pdb_index] + new_atom37_positions[i] = atom37_positions[pdb_index] + new_atom37_mask[i] = atom37_mask[pdb_index] + align_mask[i] = 1 + pdb_index += 1 + out = new_aatype, new_atom37_positions, new_atom37_mask, align_mask + return out + + def process_pdb(self, features, data): + """get atom information from pdb""" + decoy_aatype = data["decoy_aatype"] + decoy_atom37_positions = data["decoy_atom_positions"].astype(np.float32) + decoy_atom37_mask = data["decoy_atom_mask"].astype(np.float32) + ori_res_length = data['msa'].shape[1] + padding_val = features["aatype"][0].shape[0] - ori_res_length + true_aatype = features["aatype"][0][:ori_res_length] + decoy_aatype, 
decoy_atom37_positions, decoy_atom37_mask, align_mask = \ + self.align_with_aatype(true_aatype, decoy_aatype, decoy_atom37_positions, decoy_atom37_mask) + decoy_atom37_positions = np.pad(decoy_atom37_positions, ((0, padding_val), (0, 0), (0, 0))) + decoy_atom37_mask = np.pad(decoy_atom37_mask, ((0, padding_val), (0, 0))) + align_mask = np.pad(align_mask, (0, padding_val)) + + features["decoy_atom_positions"] = decoy_atom37_positions + features["decoy_atom_mask"] = decoy_atom37_mask + features["align_mask"] = align_mask + + return features diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/module/assessment_wrapcell.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/module/assessment_wrapcell.py new file mode 100644 index 0000000000000000000000000000000000000000..ed5fc07c3cf3afa6e4828eb3e92dc790d730f1f8 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/module/assessment_wrapcell.py @@ -0,0 +1,163 @@ +# Copyright 2023 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""warp cell""" + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import ops +from mindspore.context import ParallelMode +from mindspore.nn import DistributedGradReducer +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.parallel._utils import _get_device_num +from mindspore.parallel._utils import (_get_gradients_mean, _get_parallel_mode) +# pylint: disable=relative-beyond-top-level +from .loss_module import LossNetAssessment as LossNet + +GRADIENT_CLIP_TYPE = 1 + +clip_grad = ops.MultitypeFuncGraph("clip_grad") + + +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """_clip_grad""" + if clip_type not in (0, 1): + return grad + dt = ops.dtype(grad) + if clip_type == 0: + new_grad = ops.clip_by_value(grad, ops.cast(ops.tuple_to_array((-clip_value,)), dt), + ops.cast(ops.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, ops.cast(ops.tuple_to_array((clip_value,)), dt)) + return new_grad + + +grad_scale = C.MultitypeFuncGraph("grad_scale") + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + """tensor_grad_scale""" + return grad * ops.Reciprocal()(scale) + + +class TrainOneStepCell(nn.Cell): + """TrainOneStepCell""" + def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True, use_global_norm=True, + gradient_clip_value=1.0): + super(TrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.optimizer = optimizer + self.weights = self.optimizer.parameters + self.grad = ops.GradOperation(get_by_list=True, sens_param=True) + self.sens = sens + self.enable_clip_grad = enable_clip_grad + self.hyper_map = ops.HyperMap() + 
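+        # Two clipping modes are configured here and applied in construct():
+        # with use_global_norm=True the joint norm of all gradients is clipped to
+        # gradient_clip_value via C.clip_by_global_norm; otherwise clip_grad
+        # (registered above) clips each gradient tensor independently, by value
+        # for clip type 0 or by its own L2 norm for clip type 1.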
self.use_global_norm = use_global_norm + self.gradient_clip_value = gradient_clip_value + + self.reducer_flag = False + self.grad_reducer = F.identity + self.parallel_mode = _get_parallel_mode() + self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) + if self.reducer_flag: + self.mean = _get_gradients_mean() + self.degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree) + + def construct(self, *inputs): + """construct""" + if self.train_backward: + loss_all = self.network(*inputs) + loss, l_fape_side, l_fape_backbone, l_anglenorm, predict_lddt_loss, \ + distogram_focal_loss, distogram_regression_loss, plddt_dist_loss, mask_loss, \ + confidence_loss, cameo_loss = loss_all + sens = F.fill(loss.dtype, loss.shape, self.sens) + sens1 = F.fill(l_fape_side.dtype, l_fape_side.shape, 0.0) + sens2 = F.fill(l_fape_backbone.dtype, l_fape_backbone.shape, 0.0) + sens3 = F.fill(l_anglenorm.dtype, l_anglenorm.shape, 0.0) + sens4 = F.fill(predict_lddt_loss.dtype, predict_lddt_loss.shape, 0.0) + sens5 = F.fill(distogram_focal_loss.dtype, distogram_focal_loss.shape, 0.0) + sens6 = F.fill(distogram_regression_loss.dtype, distogram_regression_loss.shape, 0.0) + sens7 = F.fill(plddt_dist_loss.dtype, plddt_dist_loss.shape, 0.0) + sens8 = F.fill(mask_loss.dtype, mask_loss.shape, 0.0) + sens9 = F.fill(confidence_loss.dtype, confidence_loss.shape, 0.0) + sens10 = F.fill(cameo_loss.dtype, cameo_loss.shape, 0.0) + grads = self.grad(self.network, self.weights)(*inputs, (sens, sens1, sens2, sens3, sens4, + sens5, sens6, sens7, sens8, sens9, sens10)) + grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_tensor(self.sens)), grads) + grads = self.grad_reducer(grads) + if self.enable_clip_grad: + if self.use_global_norm: + grads = C.clip_by_global_norm(grads, self.gradient_clip_value) + else: + grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, self.gradient_clip_value), grads) + + loss_all = F.depend(loss_all, self.optimizer(grads)) + return loss_all + + out = self.network(*inputs) + return out + + +class WithLossCell(nn.Cell): + """WithLossCell""" + def __init__(self, backbone, config): + super(WithLossCell, self).__init__(auto_prefix=False) + self._backbone = backbone + self.loss_net = LossNet(config).to_float(mstype.float32) + + def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion, + extra_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, residue_index, + prev_pos, prev_msa_first_row, prev_pair, + pseudo_beta_gt, pseudo_beta_mask_gt, all_atom_mask_gt, true_msa, bert_mask, + residx_atom14_to_atom37, restype_atom14_bond_lower_bound, restype_atom14_bond_upper_bound, + atomtype_radius, backbone_affine_tensor, backbone_affine_mask, + atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous, atom14_gt_exists, + atom14_atom_exists, atom14_alt_gt_exists, all_atom_positions, rigidgroups_gt_frames, + rigidgroups_gt_exists, rigidgroups_alt_gt_frames, torsion_angles_sin_cos_gt, chi_mask, + decoy_atom_positions, decoy_atom_mask): + """construct""" + dist_logits, bin_edges, atom14_pred_positions, final_affines, angles_sin_cos_new, \ + predicted_lddt_logits, structure_traj, sidechain_frames, sidechain_atom_pos, \ + um_angles_sin_cos_new, final_atom_positions, decoy_pseudo_beta, \ + 
decoy_pseudo_beta_mask, decoy_logits, plddt_dist, pred_mask2d = \ + self._backbone(target_feat, msa_feat, msa_mask, seq_mask, aatype, template_aatype, + template_all_atom_masks, template_all_atom_positions, template_mask, + template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion, + extra_deletion_value, extra_msa_mask, residx_atom37_to_atom14, atom37_atom_exists, + residue_index, prev_pos, prev_msa_first_row, prev_pair, decoy_atom_positions, + decoy_atom_mask, run_pretrain=False) + out = self.loss_net(dist_logits, bin_edges, pseudo_beta_gt, pseudo_beta_mask_gt, + atom37_atom_exists, all_atom_mask_gt, true_msa, + bert_mask, atom14_pred_positions, residue_index, aatype, + residx_atom14_to_atom37, restype_atom14_bond_lower_bound, + restype_atom14_bond_upper_bound, seq_mask, atomtype_radius, final_affines, + angles_sin_cos_new, um_angles_sin_cos_new, backbone_affine_tensor, backbone_affine_mask, + atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous, + atom14_gt_exists, atom14_atom_exists, atom14_alt_gt_exists, + final_atom_positions, all_atom_positions, predicted_lddt_logits, + structure_traj, rigidgroups_gt_frames, rigidgroups_gt_exists, + rigidgroups_alt_gt_frames, + sidechain_frames, sidechain_atom_pos, torsion_angles_sin_cos_gt, + chi_mask, decoy_pseudo_beta, decoy_pseudo_beta_mask, decoy_logits, + plddt_dist, pred_mask2d) + + return out diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/module/head.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/module/head.py new file mode 100644 index 0000000000000000000000000000000000000000..36e79f092c9e5169906c6a7c24cd959d6ad934ad --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/module/head.py @@ -0,0 +1,249 @@ +# Copyright 2023 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""head modules"""
+
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+import mindspore.numpy as mnp
+from mindspore import Tensor
+from mindspore.ops import functional as F
+from mindsponge.pipeline.cell.initializer import lecun_init
+
+
+class PredictedLDDTHead(nn.Cell):
+    """Head to predict the per-residue LDDT to be used as a confidence measure."""
+
+    def __init__(self, config, seq_channel):
+        super().__init__()
+        self.config = config
+        self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5)
+        self.act_0 = nn.Dense(seq_channel, self.config.num_channels,
+                              weight_init=lecun_init(seq_channel, initializer_name='relu')
+                              ).to_float(mstype.float16)
+        self.act_1 = nn.Dense(self.config.num_channels, self.config.num_channels,
+                              weight_init=lecun_init(self.config.num_channels, initializer_name='relu')
+                              ).to_float(mstype.float16)
+        self.logits = nn.Dense(self.config.num_channels, self.config.num_bins, weight_init='zeros'
+                               ).to_float(mstype.float16)
+        self.relu = nn.ReLU()
+
+    def construct(self, rp_structure_module):
+        """Builds PredictedLDDTHead module."""
+        act = rp_structure_module
+        act = self.input_layer_norm(act.astype(mstype.float32))
+        act = self.act_0(act)
+        act = self.relu(act.astype(mstype.float32))
+        act = self.act_1(act)
+        act = self.relu(act.astype(mstype.float32))
+        logits = self.logits(act)
+        return logits
+
+
+class DistogramHead(nn.Cell):
+    """Head to predict a distogram.
+
+    Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction"
+    """
+
+    def __init__(self, config, pair_dim):
+        super().__init__()
+        self.config = config
+        self.half_logits = nn.Dense(pair_dim, self.config.num_bins, weight_init='zeros')
+        self.first_break = self.config.first_break
+        self.last_break = self.config.last_break
+        self.num_bins = self.config.num_bins
+
+    def construct(self, pair):
+        """Builds DistogramHead module.
+
+        Arguments:
+            representations: Dictionary of representations, must contain:
+                * 'pair': pair representation, shape [N_res, N_res, c_z].
+
+        Returns:
+            Dictionary containing:
+                * logits: logits for distogram, shape [N_res, N_res, N_bins].
+                * bin_breaks: array containing bin breaks, shape [N_bins - 1,].
+        """
+        half_logits = self.half_logits(pair)
+
+        logits = half_logits + mnp.swapaxes(half_logits, -2, -3)
+        breaks = mnp.linspace(self.first_break, self.last_break, self.num_bins - 1)
+
+        return logits, breaks
+
+
+class ExperimentallyResolvedHead(nn.Cell):
+    """Predicts if an atom is experimentally resolved in a high-res structure.
+
+    Only trained on high-resolution X-ray crystals & cryo-EM.
+    Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction'
+    """
+
+    def __init__(self, seq_channel):
+        super().__init__()
+        self.logits = nn.Dense(seq_channel, 37, weight_init='zeros')
+
+    def construct(self, single):
+        """Builds ExperimentallyResolvedHead module.
+
+        Arguments:
+            representations: Dictionary of representations, must contain:
+                * 'single': Single representation, shape [N_res, c_s].
+
+        Returns:
+            Dictionary containing:
+                * 'logits': logits of shape [N_res, 37],
+                    log probability that an atom is resolved in atom37 representation,
+                    can be converted to probability by applying sigmoid.
+        """
+        logits = self.logits(single)
+        return logits
+
+
+class MaskedMsaHead(nn.Cell):
+    """Head to predict MSA at the masked locations.
+ + The MaskedMsaHead employs a BERT-style objective to reconstruct a masked + version of the full MSA, based on a linear projection of + the MSA representation. + Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction" + """ + + def __init__(self, config, msa_channel): + super().__init__() + self.config = config + self.logits = nn.Dense(msa_channel, self.config.num_output, weight_init='zeros') + + def construct(self, msa): + """Builds MaskedMsaHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'msa': MSA representation, shape [N_seq, N_res, c_m]. + + Returns: + Dictionary containing: + * 'logits': logits of shape [N_seq, N_res, N_aatype] with + (unnormalized) log probabilies of predicted aatype at position. + """ + # del batch + logits = self.logits(msa) + return logits + + +class PredictedAlignedErrorHead(nn.Cell): + """Head to predict the distance errors in the backbone alignment frames. + + Can be used to compute predicted TM-Score. + Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction" + """ + + def __init__(self, config, pair_dim): + super().__init__() + self.config = config + self.num_bins = self.config.num_bins + self.max_error_bin = self.config.max_error_bin + self.logits = nn.Dense(pair_dim, self.num_bins, weight_init='zeros') + + def construct(self, pair): + """Builds PredictedAlignedErrorHead module. + + Arguments: + * 'pair': pair representation, shape [N_res, N_res, c_z]. + + Returns: + * logits: logits for aligned error, shape [N_res, N_res, N_bins]. + * breaks: array containing bin breaks, shape [N_bins - 1]. + """ + logits = self.logits(pair) + breaks = mnp.linspace(0, self.max_error_bin, self.num_bins - 1) + return logits, breaks + + +class EstogramHead(nn.Cell): + """Head to predict estogram.""" + + def __init__(self, first_break, last_break, num_bins): + super().__init__() + self.first_break = first_break + self.last_break = last_break + self.num_bins = num_bins + + self.breaks = mnp.linspace(self.first_break, self.last_break, self.num_bins) + self.width = self.breaks[1] - self.breaks[0] + + self.centers = self.breaks + 0.5 * self.width + + self.softmax = nn.Softmax(-1) + self.zero = Tensor([0.]) + + def compute_estogram(self, distogram_logits, decoy_distance_mat): + """compute estogram matrix. + Arguments: + distogram_logits: [N_res, N_res, N_bins]. + decoy_distance_mat: [N_res, N_res] + Returns: + estogram: shape [N_res, N_res, N_bins]. + esto_centers: shape [N_res, N_res, N_bins]. + """ + square_centers = mnp.reshape(self.centers, (1, 1, -1)) + estogram = self.softmax(distogram_logits) + esto_centers = square_centers - mnp.expand_dims(decoy_distance_mat, -1) + return estogram, esto_centers + + def construct(self, distogram_logits, pseudo_beta, pseudo_beta_mask, cutoff=15.): + """construct""" + positions = pseudo_beta + pad_mask = mnp.expand_dims(pseudo_beta_mask, 1) + pad_mask_2d = pad_mask * mnp.transpose(pad_mask, (1, 0)) + pad_mask_2d *= (1. 
- mnp.eye(pad_mask_2d.shape[1])) + + dist_xyz = mnp.square(mnp.expand_dims(positions, axis=1) - mnp.expand_dims(positions, axis=0)) + dmat_decoy = mnp.sqrt(1e-10 + mnp.sum(dist_xyz.astype(mstype.float32), -1)) + + estogram, esto_centers = self.compute_estogram(distogram_logits, dmat_decoy) + pair_errors = mnp.sum(estogram * esto_centers, -1) + + p1 = self._integrate(distogram_logits, mnp.abs(esto_centers) < 0.5).astype(mnp.float32) + p2 = self._integrate(distogram_logits, mnp.abs(esto_centers) < 1.0).astype(mnp.float32) + p3 = self._integrate(distogram_logits, mnp.abs(esto_centers) < 2.0).astype(mnp.float32) + p4 = self._integrate(distogram_logits, mnp.abs(esto_centers) < 4.0).astype(mnp.float32) + + p0 = self._integrate(distogram_logits, self.centers < cutoff).astype(mnp.float32) + pred_mask2d = p0 * pad_mask_2d + + norm = mnp.sum(pred_mask2d, -1) + 1e-6 + p1 = mnp.sum(p1 * pred_mask2d, -1) + p2 = mnp.sum(p2 * pred_mask2d, -1) + p3 = mnp.sum(p3 * pred_mask2d, -1) + p4 = mnp.sum(p4 * pred_mask2d, -1) + + plddt = 0.25 * (p1 + p2 + p3 + p4) / norm + + return plddt, pred_mask2d, pair_errors + + def _integrate(self, distogram_logits, integrate_masks): + """compute estogram matrix. + Arguments: + distogram_logits: [N_res, N_res, N_bins]. + integrate_masks: [N_res, N_res, N_bins] + Returns: + v: shape [N_res, N_res]. + """ + probs = self.softmax(distogram_logits) + integrate_masks = F.cast(integrate_masks, mnp.float32) + v = mnp.sum(probs * integrate_masks, -1) + return v diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/module/loss_module.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/module/loss_module.py new file mode 100644 index 0000000000000000000000000000000000000000..a3acfbc6b670000585f684b726cc081f21215d1b --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/module/loss_module.py @@ -0,0 +1,244 @@ +# Copyright 2023 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""loss module""" + +import mindspore as ms +import mindspore.communication.management as D +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore.context import ParallelMode +from mindspore.parallel._utils import _get_parallel_mode +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindsponge.metrics.structure_violations import local_distance_difference_test +from mindsponge.metrics import BalancedMSE, BinaryFocal, MultiClassFocal +# pylint: disable=relative-beyond-top-level +from ...megafold.module.loss_module import LossNet + + +class LossNetAssessment(nn.Cell): + """loss net""" + + def __init__(self, config): + super(LossNetAssessment, self).__init__() + self.orign_loss = LossNet(config, train_fold=False) + self.num_bins = config.model.heads.distogram.num_bins + self.cutoff = 15.0 + self.within_cutoff_clip = 0.3 + self.beyond_cutoff_clip = 3.0 + self.beyond_cutoff_weight = 0.2 + self.regressor_idx = 1 + self.regressor_weight = 2. + self.parallel_mode = _get_parallel_mode() + self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) + if self.reducer_flag: + self.allreduce = P.AllReduce() + self.device_num = D.get_group_size() + + self.reg_loss_distogram = RegressionLosses(first_break=2., last_break=22., + num_bins=config.model.heads.distogram.num_bins, bin_shift=True, + charbonnier_eps=0.1, reducer_flag=self.reducer_flag) + self.reg_loss_lddt = RegressionLosses(first_break=0., last_break=1., + num_bins=config.model.heads.predicted_lddt.num_bins, bin_shift=False, + charbonnier_eps=1e-5, reducer_flag=self.reducer_flag) + + self.binary_focal_loss = BinaryFocal(alpha=0.25, gamma=1., feed_in=False, not_focal=False) + self.softmax_focal_loss_lddt = MultiClassFocal(num_class=config.model.heads.predicted_lddt.num_bins, + gamma=1., e=0.1, neighbors=2, not_focal=False, + reducer_flag=self.reducer_flag) + self.softmax_focal_loss_distogram = MultiClassFocal(num_class=config.model.heads.distogram.num_bins, + gamma=1., e=0.1, neighbors=2, not_focal=False, + reducer_flag=self.reducer_flag) + self.cameo_focal_loss = BinaryFocal(alpha=0.2, gamma=0.5, feed_in=True, not_focal=False) + self.distogram_one_hot = nn.OneHot(depth=self.num_bins, axis=-1) + self.breaks = mnp.linspace(2.0, 22.0, self.num_bins) + self.width = self.breaks[1] - self.breaks[0] + self.centers = self.breaks + 0.5 * self.width + + def distogram_loss(self, logits, bin_edges, pseudo_beta, pseudo_beta_mask): + """Log loss of a distogram.""" + positions = pseudo_beta + mask = pseudo_beta_mask + + sq_breaks = mnp.square(bin_edges) + dist_t = mnp.square(mnp.expand_dims(positions, axis=-2) - mnp.expand_dims(positions, axis=-3)) + dist2 = P.ReduceSum(True)(dist_t.astype(ms.float32), -1) + aa = (dist2 > sq_breaks).astype(ms.float32) + + square_mask = mnp.expand_dims(mask, axis=-2) * mnp.expand_dims(mask, axis=-1) + probs = nn.Softmax(-1)(logits) + dmat_pred = mnp.sum(probs * mnp.reshape(self.centers, (1, 1, -1)), -1) + dist2 = dist2[..., 0] + dmat_true = mnp.sqrt(1e-6 + dist2) + + within_cutoff_mask = F.cast(dmat_true < self.cutoff, mnp.float32) + within_cutoff_mask *= (1. 
- mnp.eye(within_cutoff_mask.shape[1])) + beyond_cutoff_mask = F.cast(dmat_true > self.cutoff, mnp.float32) + beyond_cutoff_mask *= self.beyond_cutoff_weight + + true_bins = P.ReduceSum()(aa, -1) + true_bins = true_bins.astype(ms.int32) + + nres, nres, nbins = logits.shape + logits = mnp.reshape(logits, (-1, nbins)) + labels = self.distogram_one_hot(true_bins) + labels = mnp.reshape(labels, (-1, nbins)) + + error = self.softmax_focal_loss_distogram(logits, labels) + error = mnp.reshape(error, (nres, nres)) + focal_error = within_cutoff_mask * error + beyond_cutoff_mask * error + + focal_loss = (P.ReduceSum()(focal_error * square_mask, (-2, -1)) / + (1e-6 + P.ReduceSum()(square_mask.astype(ms.float32), (-2, -1)))) + + error_tuple = self.reg_loss_distogram(dmat_pred, dmat_true) + regression_error = error_tuple[1] + + regression_error_clip_within = mnp.clip(regression_error, self.within_cutoff_clip, + 20.) - self.within_cutoff_clip + regression_error_clip_beyond = mnp.clip(regression_error, self.beyond_cutoff_clip, + 20.) - self.beyond_cutoff_clip + + regression_error = regression_error_clip_within * within_cutoff_mask + regression_error_clip_beyond \ + * beyond_cutoff_mask + + square_mask_off_diagonal = square_mask * (1 - mnp.eye(square_mask.shape[1])) + + regression_loss = (P.ReduceSum()(regression_error * square_mask_off_diagonal, (-2, -1)) / + (1e-6 + P.ReduceSum()(square_mask_off_diagonal.astype(ms.float32), (-2, -1)))) + + loss = focal_loss + self.regressor_weight * regression_loss + + dist_loss = loss, focal_loss, regression_loss, dmat_true + + return dist_loss + + def construct(self, distogram_logits, bin_edges, pseudo_beta, pseudo_beta_mask, + atom37_atom_exists, all_atom_mask, true_msa, bert_mask, + final_atom14_positions, residue_index, aatype, residx_atom14_to_atom37, lower_bound, upper_bound, + seq_mask, atomtype_radius, final_affines, angles_sin_cos, + um_angles_sin_cos, backbone_affine_tensor, backbone_affine_mask, atom14_gt_positions, + atom14_alt_gt_positions, atom14_atom_is_ambiguous, atom14_gt_exists, atom14_atom_exists, + atom14_alt_gt_exists, final_atom_positions, all_atom_positions, predicted_lddt_logits, traj, + rigidgroups_gt_frames, rigidgroups_gt_exists, rigidgroups_alt_gt_frames, + pred_frames, pred_positions, sin_cos_true_chi, torsion_angle_mask, + decoy_pseudo_beta, decoy_pseudo_beta_mask, decoy_predicted_lddt_logits, plddt_dist, pred_mask2d): + """construct""" + _, l_fape_side, l_fape_backbone, l_anglenorm, _, _, predict_lddt_loss = self.orign_loss( + distogram_logits, bin_edges, pseudo_beta, pseudo_beta_mask, None, + atom37_atom_exists, all_atom_mask, true_msa, None, bert_mask, + final_atom14_positions, residue_index, aatype, residx_atom14_to_atom37, lower_bound, upper_bound, + seq_mask, atomtype_radius, final_affines, None, None, angles_sin_cos, + um_angles_sin_cos, backbone_affine_tensor, backbone_affine_mask, atom14_gt_positions, + atom14_alt_gt_positions, atom14_atom_is_ambiguous, atom14_gt_exists, atom14_atom_exists, + atom14_alt_gt_exists, final_atom_positions, all_atom_positions, predicted_lddt_logits, traj, + rigidgroups_gt_frames, rigidgroups_gt_exists, rigidgroups_alt_gt_frames, + pred_frames, pred_positions, sin_cos_true_chi, torsion_angle_mask, 1.0, 1.0) + + fold_loss = l_fape_side + l_fape_backbone + l_anglenorm + predict_lddt_loss + + lddt_cb = local_distance_difference_test( + predicted_points=decoy_pseudo_beta[None, ...], + true_points=pseudo_beta[None, ...], + true_points_mask=decoy_pseudo_beta_mask[None, ..., None].astype(mnp.float32), + 
cutoff=15., + per_residue=True)[0] + lddt_cb = F.stop_gradient(lddt_cb) + + distogram_loss, distogram_focal_loss, distogram_regression_loss, dmat_true = self.distogram_loss( + distogram_logits, bin_edges, pseudo_beta, pseudo_beta_mask) + + mask1d = decoy_pseudo_beta_mask + mask2d = mnp.expand_dims(mask1d, 1) * mnp.expand_dims(mask1d, 0) + error_tuple = self.reg_loss_lddt(plddt_dist, lddt_cb) + plddt2_error = error_tuple[self.regressor_idx] + + plddt2_regression_loss = mnp.sum(plddt2_error * mask1d) / (mnp.sum(mask1d) + 1e-8) + plddt2_loss = self.regressor_weight * plddt2_regression_loss + + true_mask2d = P.Cast()(dmat_true < self.cutoff, ms.float32) + mask_error = self.binary_focal_loss(mnp.reshape(pred_mask2d, (-1,)), mnp.reshape(true_mask2d, (-1,))) + mask_error = mnp.reshape(mask_error, true_mask2d.shape) + mask_loss = mnp.sum(mask_error * mask2d) / (mnp.sum(mask2d) + 1e-6) + confidence_pred = mnp.sum(plddt_dist * mask1d) / (mnp.sum(mask1d) + 1e-6) + confidence_gt = mnp.sum(lddt_cb * mask1d) / (mnp.sum(mask1d) + 1e-6) + confidence_loss = nn.MSELoss()(confidence_pred, confidence_gt) + confidence_loss = mnp.sqrt(confidence_loss + 1e-5) + + cameo_label = F.cast(lddt_cb < 0.6, mnp.float32) + cameo_scale = decoy_predicted_lddt_logits[:, 0] + cameo_shift = decoy_predicted_lddt_logits[:, 1] + cameo_scale = 5. * P.Tanh()(cameo_scale / 5.) + decoy_cameo_logit = -F.exp(cameo_scale) * (plddt_dist + cameo_shift - 0.6) + cameo_error = self.cameo_focal_loss(decoy_cameo_logit, cameo_label) + cameo_loss = mnp.sum(cameo_error * mask1d) / (mnp.sum(mask1d) + 1e-6) + + score_loss = distogram_loss + plddt2_loss + mask_loss + 2.0 * confidence_loss + 2.0 * cameo_loss + + loss = 0.5 * fold_loss + score_loss + + seq_len = F.cast(P.ReduceSum()(pseudo_beta_mask), mnp.float32) + loss_weight = mnp.power(seq_len, 0.5) + if self.reducer_flag: + loss_weight_sum = self.allreduce(loss_weight) / self.device_num + loss_weight = loss_weight / loss_weight_sum + loss_weight *= 64. 
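+        # Length-aware reweighting: the total loss is scaled by sqrt(sequence
+        # length) so longer chains carry more weight; under data parallelism the
+        # weight is first normalised by its all-reduced mean across devices and
+        # then rescaled by the constant 64.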
+ + loss = loss * loss_weight + + loss_all = loss, l_fape_side, l_fape_backbone, l_anglenorm, predict_lddt_loss, \ + distogram_focal_loss, distogram_regression_loss, plddt2_regression_loss, mask_loss, \ + confidence_loss, cameo_loss + + return loss_all + + +class RegressionLosses(nn.Cell): + """Return various regressor losses""" + + def __init__(self, first_break, last_break, num_bins, bin_shift=True, beta=0.99, charbonnier_eps=1e-5, + reducer_flag=False): + super(RegressionLosses, self).__init__() + + self.beta = beta + self.charbonnier_eps = charbonnier_eps + + self.first_break = first_break + self.last_break = last_break + self.num_bins = num_bins + self.breaks = mnp.linspace(self.first_break, self.last_break, self.num_bins) + self.width = self.breaks[1] - self.breaks[0] + + bin_width = 2 + start_n = 1 + stop = self.num_bins * 2 + centers = mnp.divide(mnp.arange(start=start_n, stop=stop, step=bin_width), self.num_bins * 2.0) + self.centers = centers / (self.last_break - self.first_break) + self.first_break + + if bin_shift: + centers = mnp.linspace(self.first_break, self.last_break, self.num_bins) + self.centers = centers + 0.5 * self.width + self.mse = nn.MSELoss() + self.mae = nn.L1Loss() + self.bmse = BalancedMSE(first_break, last_break, num_bins, beta, reducer_flag) + + def construct(self, prediction, target): + """construct""" + target = mnp.clip(target, self.centers[0], self.centers[-1]) + + mse = self.mse(prediction, target) + mae = self.mae(prediction, target) + bmse = self.bmse(prediction, target) + return [mse, mae, bmse] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/nn_arch.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/nn_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..13b4ad3c425724efbcda2e76a9282f564ac96e22 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/megaassessment/nn_arch.py @@ -0,0 +1,350 @@ +# Copyright 2023 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""nn arch""" + +from collections import defaultdict +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore import Parameter, load_checkpoint + +from ....common.utils import dgram_from_positions, pseudo_beta_fn, atom37_to_torsion_angles +from ...cell.initializer import lecun_init +from ..megafold.module.template_embedding import TemplateEmbedding +from ..megafold.module.evoformer import Evoformer +from ..megafold.module.structure import StructureModule +from ..megafold.module.head import DistogramHead, PredictedLDDTHead, EstogramHead +from ..megafold.nn_arch import caculate_constant_array, megafold + + +def load_weights(model_path, config): + """ + Load checkpoint as parameter dict, support both npz file and mindspore checkpoint file. + """ + ms_ckpt = load_checkpoint(model_path) + weights = defaultdict(str) + for msname in ms_ckpt: + if "msa_stack" in msname and "extra" not in msname: + for i in range(config.evoformer.msa_stack_num): + temp_name = msname.split(".") + temp_name.insert(1, str(i)) + infer_name = "fold." + ".".join(temp_name) + weights[infer_name] = ms_ckpt[msname].data.asnumpy()[i] + + for i in range(8): + temp_name = msname.split(".") + temp_name.insert(1, str(i)) + infer_name = "assessment." + ".".join(temp_name) + weights[infer_name] = ms_ckpt[msname].data.asnumpy()[i] + else: + infer_name = "fold." + msname + weights[infer_name] = ms_ckpt[msname].data.asnumpy() + infer_name = "assessment." + msname + weights[infer_name] = ms_ckpt[msname].data.asnumpy() + + parameter_dict = defaultdict(str) + for name in weights: + parameter_dict[name] = Parameter(Tensor(weights[name]), name=name) + return parameter_dict + + +class CombineModel(nn.Cell): + """Combine MegaFold and MegaAssessment""" + + def __init__(self, config, mixed_precision): + super(CombineModel, self).__init__() + self.fold = megafold(config, mixed_precision=mixed_precision) + config.model.evoformer.extra_msa_stack_num = 4 + config.model.evoformer.msa_stack_num = 8 + self.assessment = MegaAssessment(config, mixed_precision=mixed_precision) + + def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion, + extra_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, residue_index, + prev_pos, prev_msa_first_row, prev_pair, decoy_atom_positions=None, + decoy_atom_mask=None, run_pretrain=True): + """construct""" + if run_pretrain: + out = self.fold(target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, + extra_has_deletion, extra_deletion_value, extra_msa_mask, residx_atom37_to_atom14, + atom37_atom_exists, residue_index, prev_pos, prev_msa_first_row, prev_pair) + else: + out = self.assessment(target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, + extra_has_deletion, extra_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, residue_index, + prev_pos, prev_msa_first_row, prev_pair, 
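+                                  # Note (added): the decoy coordinates below are consumed
+                                  # only by the assessment branch (run_pretrain=False),
+                                  # never by the fold branch above.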
decoy_atom_positions, + decoy_atom_mask) + return out + + +class MegaAssessment(nn.Cell): + """MegaAssessment""" + + def __init__(self, config, mixed_precision): + super(MegaAssessment, self).__init__() + + self.cfg = config + + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.is_training = self.cfg.is_training + self.recycle_pos = self.cfg.model.recycle_pos + self.recycle_features = self.cfg.model.recycle_features + self.max_relative_feature = self.cfg.model.max_relative_feature + self.num_bins = self.cfg.model.prev_pos.num_bins + self.min_bin = self.cfg.model.prev_pos.min_bin + self.max_bin = self.cfg.model.prev_pos.max_bin + self.template_enabled = self.cfg.model.template.enabled + self.template_embed_torsion_angles = self.cfg.model.template.embed_torsion_angles + self.extra_msa_stack_num = self.cfg.model.evoformer.extra_msa_stack_num + self.msa_stack_num = self.cfg.model.evoformer.msa_stack_num + self.chi_atom_indices, self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, \ + self.indices0, self.indices1 = caculate_constant_array(self.cfg.seq_length) + + self.preprocess_1d = nn.Dense(self.cfg.model.common.target_feat_dim, self.cfg.model.msa_channel, + weight_init=lecun_init(self.cfg.model.common.target_feat_dim)) + self.preprocess_msa = nn.Dense(self.cfg.model.common.msa_feat_dim, self.cfg.model.msa_channel, + weight_init=lecun_init(self.cfg.model.common.msa_feat_dim)) + self.left_single = nn.Dense(self.cfg.model.common.target_feat_dim, self.cfg.model.pair_channel, + weight_init=lecun_init(self.cfg.model.common.target_feat_dim)) + self.right_single = nn.Dense(self.cfg.model.common.target_feat_dim, self.cfg.model.pair_channel, + weight_init=lecun_init(self.cfg.model.common.target_feat_dim)) + self.prev_pos_linear = nn.Dense(self.cfg.model.common.dgram_dim, self.cfg.model.pair_channel, + weight_init=lecun_init(self.cfg.model.common.dgram_dim)) + self.pair_activations = nn.Dense(self.cfg.model.common.pair_in_dim, self.cfg.model.pair_channel, + weight_init=lecun_init(self.cfg.model.common.pair_in_dim)) + self.extra_msa_one_hot = nn.OneHot(depth=23, axis=-1) + self.template_aatype_one_hot = nn.OneHot(depth=22, axis=-1) + self.prev_msa_first_row_norm = nn.LayerNorm([256,], epsilon=1e-5) + self.prev_pair_norm = nn.LayerNorm([128,], epsilon=1e-5) + self.one_hot = nn.OneHot(depth=self.cfg.model.max_relative_feature * 2 + 1, axis=-1) + self.extra_msa_activations = nn.Dense(25, self.cfg.model.extra_msa_channel, weight_init=lecun_init(25)) + self.template_embedding = TemplateEmbedding(self.cfg.model, self.is_training, mixed_precision) + + self.matmul_trans_b = P.MatMul(transpose_b=True) + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.template_single_embedding = nn.Dense(57, self.cfg.model.msa_channel, + weight_init= + lecun_init(57, initializer_name='relu')) + self.template_projection = nn.Dense(self.cfg.model.msa_channel, self.cfg.model.msa_channel, + weight_init=lecun_init(self.cfg.model.msa_channel, + initializer_name='relu')) + self.relu = nn.ReLU() + self.single_activations = nn.Dense(self.cfg.model.msa_channel, self.cfg.model.seq_channel, + weight_init=lecun_init(self.cfg.model.msa_channel)) + extra_msa_stack = nn.CellList() + for _ in range(self.extra_msa_stack_num): + extra_msa_block = Evoformer(self.cfg.model, + msa_act_dim=64, + pair_act_dim=128, + is_extra_msa=True, + is_training=self.is_training, + batch_size=None) + extra_msa_stack.append(extra_msa_block) + self.extra_msa_stack = extra_msa_stack + if 
self.is_training: + msa_stack = nn.CellList() + for _ in range(self.msa_stack_num): + msa_block = Evoformer(self.cfg.model, + msa_act_dim=256, + pair_act_dim=128, + is_extra_msa=False, + is_training=self.is_training, + batch_size=None) + msa_stack.append(msa_block) + self.msa_stack = msa_stack + else: + self.msa_stack = Evoformer(self.cfg.model, + msa_act_dim=256, + pair_act_dim=128, + is_extra_msa=False, + is_training=self.is_training, + batch_size=self.msa_stack_num) + self.idx_evoformer_block = Parameter(Tensor(0, mstype.int32), requires_grad=False) + self.evoformer_num_block_eval = Tensor(self.msa_stack_num, mstype.int32) + + self.structure_module = StructureModule(self.cfg.model, + self.cfg.model.seq_channel, + self.cfg.model.pair_channel, + self.cfg.seq_length) + + self.module_lddt = PredictedLDDTHead(self.cfg.model.heads.predicted_lddt, + self.cfg.model.seq_channel) + self.module_distogram = DistogramHead(self.cfg.model.heads.distogram, + self.cfg.model.pair_channel) + if self.is_training: + self.module_lddt_decoy = PredictedLDDTHead(self.cfg.model.heads.predicted_lddt, + self.cfg.model.seq_channel*3) + self.module_estogram = EstogramHead(first_break=self.cfg.model.heads.distogram.first_break, + last_break=self.cfg.model.heads.distogram.last_break, + num_bins=self.cfg.model.heads.distogram.num_bins) + + self.norm_0 = LayerNormDense(self.cfg.model.msa_channel, self.cfg.model.seq_channel) + self.norm_1 = LayerNormDense(self.cfg.model.msa_channel, self.cfg.model.seq_channel) + self.norm_2 = LayerNormDense(self.cfg.model.msa_channel, self.cfg.model.seq_channel) + self.extra_msa_length = 4 + self.msa_cluster_length = 4 + + def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion, + extra_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, residue_index, + prev_pos, prev_msa_first_row, prev_pair, decoy_atom_positions, decoy_atom_mask): + """construct""" + decoy_pseudo_beta, decoy_pseudo_beta_mask = pseudo_beta_fn(aatype, decoy_atom_positions, + atom37_atom_exists) + extra_msa = mnp.zeros_like(extra_msa[:self.extra_msa_length]) + extra_has_deletion = mnp.zeros_like(extra_has_deletion[:self.extra_msa_length]) + extra_deletion_value = mnp.zeros_like(extra_deletion_value[:self.extra_msa_length]) + extra_msa_mask = mnp.zeros_like(extra_msa_mask[:self.extra_msa_length]) + msa_feat = mnp.concatenate((msa_feat[0:1], mnp.zeros_like(msa_feat[1:self.msa_cluster_length])), axis=0) + msa_mask = mnp.concatenate((msa_mask[0:1], mnp.zeros_like(msa_mask[1:self.msa_cluster_length])), axis=0) + template_aatype = mnp.concatenate((aatype[None], mnp.zeros_like(template_aatype[1:])), axis=0) + template_mask = mnp.concatenate((mnp.ones_like(template_mask[0:1]), mnp.zeros_like(template_mask[1:])), axis=0) + template_all_atom_masks = mnp.concatenate((decoy_atom_mask[None], template_all_atom_masks[1:]), axis=0) + template_all_atom_positions = mnp.concatenate((decoy_atom_positions[None], template_all_atom_positions[1:]), + axis=0) + template_pseudo_beta_mask = mnp.concatenate((decoy_pseudo_beta_mask[None], template_pseudo_beta_mask[1:]), + axis=0) + template_pseudo_beta = mnp.concatenate((decoy_pseudo_beta[None], template_pseudo_beta[1:]), axis=0) + + preprocess_1d = self.preprocess_1d(target_feat) + preprocess_msa = self.preprocess_msa(msa_feat) + msa_activations = mnp.expand_dims(preprocess_1d, axis=0) + 
preprocess_msa + left_single = self.left_single(target_feat) + right_single = self.right_single(target_feat) + pair_activations = P.ExpandDims()(left_single, 1) + P.ExpandDims()(right_single, 0) + mask_2d = P.ExpandDims()(seq_mask, 1) * P.ExpandDims()(seq_mask, 0) + if self.recycle_pos: + prev_pseudo_beta = pseudo_beta_fn(aatype, prev_pos, None) + dgram = dgram_from_positions(prev_pseudo_beta, self.num_bins, self.min_bin, self.max_bin, self._type) + pair_activations += self.prev_pos_linear(dgram) + + if self.recycle_features: + prev_msa_first_row = self.prev_msa_first_row_norm(prev_msa_first_row) + msa_activations = mnp.concatenate( + (mnp.expand_dims(prev_msa_first_row + msa_activations[0, ...], 0), msa_activations[1:, ...]), 0) + pair_activations += self.prev_pair_norm(prev_pair) + + if self.max_relative_feature: + offset = P.ExpandDims()(residue_index, 1) - P.ExpandDims()(residue_index, 0) + rel_pos = self.one_hot(mnp.clip(offset + self.max_relative_feature, 0, 2 * self.max_relative_feature)) + pair_activations += self.pair_activations(rel_pos) + + template_pair_representation = 0 + if self.template_enabled: + template_pair_representation = self.template_embedding(pair_activations, template_aatype, + template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, + template_pseudo_beta, mask_2d) + pair_activations += template_pair_representation + msa_1hot = self.extra_msa_one_hot(extra_msa) + extra_msa_feat = mnp.concatenate((msa_1hot, extra_has_deletion[..., None], extra_deletion_value[..., None]), + axis=-1) + extra_msa_activations = self.extra_msa_activations(extra_msa_feat) + extra_msa_norm = P.ExpandDims()(P.MatMul(transpose_a=True)(extra_msa_mask, extra_msa_mask), -1) + for i in range(self.extra_msa_stack_num): + extra_msa_activations, pair_activations = \ + self.extra_msa_stack[i](extra_msa_activations, pair_activations, extra_msa_mask, extra_msa_norm, + mask_2d) + template_activations = None + if self.template_enabled and self.template_embed_torsion_angles: + num_templ, num_res = template_aatype.shape + aatype_one_hot = self.template_aatype_one_hot(template_aatype) + torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask = atom37_to_torsion_angles( + template_aatype, template_all_atom_positions, template_all_atom_masks, self.chi_atom_indices, + self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, self.indices0, self.indices1) + template_features = mnp.concatenate([aatype_one_hot, + mnp.reshape(torsion_angles_sin_cos, [num_templ, num_res, 14]), + mnp.reshape(alt_torsion_angles_sin_cos, [num_templ, num_res, 14]), + torsion_angles_mask], axis=-1) + template_activations = self.template_single_embedding(template_features) + template_activations = self.relu(template_activations) + template_activations = self.template_projection(template_activations) + msa_activations = mnp.concatenate([msa_activations, template_activations], axis=0) + torsion_angle_mask = torsion_angles_mask[:, :, 2] + msa_mask = mnp.concatenate([msa_mask, torsion_angle_mask], axis=0) + + msa_mask_norm = P.ExpandDims()(P.MatMul(transpose_a=True)(msa_mask, msa_mask), -1) + + msa_decoy = [] + msa_decoy += [self.norm_0(template_activations[0]),] + + if self.is_training: + for i in range(self.msa_stack_num): + msa_activations, pair_activations = self.msa_stack[i](msa_activations, pair_activations, msa_mask, + msa_mask_norm, mask_2d) + else: + self.idx_evoformer_block = self.idx_evoformer_block * 0 + while self.idx_evoformer_block < 
self.evoformer_num_block_eval: + msa_activations, pair_activations = self.msa_stack(msa_activations, + pair_activations, + msa_mask, + msa_mask_norm, + mask_2d, + self.idx_evoformer_block) + self.idx_evoformer_block += 1 + + msa_decoy += [self.norm_1(msa_activations[0]),] + msa_decoy += [self.norm_2(msa_activations[-4]),] + + single_activations = self.single_activations(msa_activations[0]) + + final_atom_positions, _, rp_structure_module, atom14_pred_positions, final_affines, \ + angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, sidechain_atom_pos, structure_traj = \ + self.structure_module(single_activations, + pair_activations, + seq_mask, + aatype, + residx_atom37_to_atom14, + atom37_atom_exists) + predicted_lddt_logits = self.module_lddt(rp_structure_module) + dist_logits, bin_edges = self.module_distogram(pair_activations) + plddt_dist, pred_mask2d, _ = self.module_estogram(dist_logits, decoy_pseudo_beta, decoy_pseudo_beta_mask) + if self.is_training: + msa_decoy = mnp.concatenate(msa_decoy, axis=-1) + decoy_logits = self.module_lddt_decoy(msa_decoy) + out = dist_logits, bin_edges, atom14_pred_positions, final_affines, angles_sin_cos_new,\ + predicted_lddt_logits, structure_traj, sidechain_frames, sidechain_atom_pos,\ + um_angles_sin_cos_new, final_atom_positions, decoy_pseudo_beta, decoy_pseudo_beta_mask, \ + decoy_logits, plddt_dist, pred_mask2d + return out + return plddt_dist + + +class LayerNormDense(nn.Cell): + """layernorm and dense layer""" + def __init__(self, inchannel, out_channel): + super(LayerNormDense, self).__init__() + self.norm = nn.LayerNorm([inchannel,], epsilon=1e-5) + self.act = nn.Dense(inchannel, out_channel, weight_init=lecun_init(inchannel)).to_float(mstype.float16) + + def construct(self, single_act): + """construct""" + out = self.norm(single_act.astype(mstype.float32)).astype(mstype.float16) + out = self.act(out) + + return out diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/model.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9aefc8ff58ffb599cd92fea32ad35fcd01c52b79 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/model.py @@ -0,0 +1,92 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Models"""
+from abc import ABCMeta, abstractmethod
+import os
+import ssl
+import urllib.request
+import mindspore as ms
+import mindspore.common.dtype as mstype
+from mindspore import load_checkpoint
+from mindsponge.pipeline.cell.amp import amp_convert
+
+
+class Model(metaclass=ABCMeta):
+    """Model"""
+    def __init__(self, checkpoint_url=None, network=None, name=None, white_list=None):
+        self.cache = None
+        self.checkpoint_path = None
+        self.checkpoint_url = checkpoint_url
+        self.name = name
+        self.network = network
+        self.white_list = white_list
+        if ms.get_context("device_target") == "Ascend":
+            self.network.to_float(mstype.float16)
+            amp_convert(self.network, self.white_list)
+        self._check_initialize()
+
+    @abstractmethod
+    def forward(self, data):
+        pass
+
+    @abstractmethod
+    def backward(self, data):
+        pass
+
+    @abstractmethod
+    def train_step(self):
+        pass
+
+    @abstractmethod
+    def predict(self):
+        pass
+
+    def set_cache(self, path):
+        self.cache = path
+
+    def set_checkpoint_path(self, path):
+        self.checkpoint_path = path
+
+    def from_pretrained(self):
+        if not os.path.exists(self.checkpoint_path):
+            print("Download checkpoint to ", self.checkpoint_path)
+            # pylint: disable=protected-access
+            ssl._create_default_https_context = ssl._create_unverified_context
+            urllib.request.urlretrieve(self.checkpoint_url, self.checkpoint_path)
+        load_checkpoint(self.checkpoint_path, self.network)
+
+    def _check_initialize(self):
+        if self.checkpoint_url is None:
+            raise ValueError("checkpoint url is not initialized, please check your init function")
+        if self.config is None:
+            raise ValueError("model config is not initialized, please check your init function")
+        if self.network is None:
+            raise ValueError("network is not initialized, please check your init function")
+
+    @abstractmethod
+    def _jit_forward(self, data):
+        pass
+
+    @abstractmethod
+    def _pynative_forward(self, data):
+        pass
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..40ce4406f6d972d6df929a4df2b7d4ea29505712
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2023 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""multimer""" +from .multimer import Multimer +from .multimer_dataset import MultimerDataSet +from .multimer_configuration import multimer_configuration diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d27dd78d05d135f5c629cc6a40a9e8c96ae6cae --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""module""" diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_block.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..38bc86db70d9be3434a79dd26c9b87115b9d1d1e --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_block.py @@ -0,0 +1,315 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Multimer building blocks"""
+
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+import mindspore.numpy as mnp
+import mindspore.ops as ops
+from mindspore import Parameter
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindsponge.common.geometry import apply_to_point, invert_point, vecs_from_tensor, \
+    vecs_dot_vecs, vecs_sub, vecs_cross_vecs, vecs_scale, \
+    rots_expand_dims, vecs_expand_dims, invert_rigids, rigids_mul_vecs
+from mindsponge.pipeline.cell.initializer import lecun_init
+
+
+def compute_chi_angles(aatype,  # (B, N)
+                       all_atom_pos,  # (B, N, 37, 3)
+                       all_atom_mask,  # (B, N, 37)
+                       chi_atom_indices,
+                       chi_angles_mask,
+                       indices0,
+                       indices1,
+                       batch_size=1):
+    """compute chi angles"""
+
+    aatype = mnp.minimum(aatype, 20)
+    # Collect the atoms for the chi-angles.
+    # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
+    # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
+    atom_indices = mnp.take(chi_atom_indices, aatype, axis=0)
+
+    # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
+    # atom_indices is reshaped to (batch=4, seq_length, chis=4, atoms=4, 1) for GatherNd.
+    seq_length = all_atom_pos.shape[1]
+    atom_indices = atom_indices.reshape((4, seq_length, 4, 4, 1)).astype("int32")
+    new_indices = P.Concat(4)((indices0, indices1, atom_indices))
+    chis_atom_pos = P.GatherNd()(all_atom_pos, new_indices)
+    chis_mask = mnp.take(chi_angles_mask, aatype, axis=0)
+    chi_angle_atoms_mask = P.GatherNd()(all_atom_mask, new_indices)
+
+    # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
+    chi_angle_atoms_mask = P.ReduceProd()(chi_angle_atoms_mask, -1)
+    chis_mask = chis_mask * (chi_angle_atoms_mask).astype(mnp.float32)
+    all_chi_angles = []
+    for i in range(batch_size):
+        template_chi_angles = multimer_rigids_compute_dihedral_angle(vecs_from_tensor(chis_atom_pos[i, :, :, 0, :]),
+                                                                     vecs_from_tensor(chis_atom_pos[i, :, :, 1, :]),
+                                                                     vecs_from_tensor(chis_atom_pos[i, :, :, 2, :]),
+                                                                     vecs_from_tensor(chis_atom_pos[i, :, :, 3, :]))
+        all_chi_angles.append(template_chi_angles)
+    chi_angles = mnp.stack(all_chi_angles, axis=0)
+    return chi_angles, chis_mask
+
+
+def multimer_square_euclidean_distance(v1, v2, epsilon):
+    """Squared Euclidean distance between two point sets, optionally clamped below by epsilon."""
+    difference = vecs_sub(v1, v2)
+    distance = vecs_dot_vecs(difference, difference)
+    if epsilon:
+        distance = mnp.maximum(distance, epsilon)
+    return distance
+
+
+def multimer_vecs_robust_norm(v, epsilon=1e-6):
+    """Computes the norm of vectors 'v', clamped below by epsilon."""
+    v_l2_norm = v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
+    if epsilon:
+        v_l2_norm = mnp.maximum(v_l2_norm, epsilon**2)
+    return mnp.sqrt(v_l2_norm)
+
+
+def multimer_vecs_robust_normalize(v, epsilon=1e-6):
+    """Normalizes vectors 'v' with a robust (clamped) norm."""
+    norms = multimer_vecs_robust_norm(v, epsilon)
+    return (v[0] / norms, v[1] / norms, v[2] / norms)
+
+
+def multimer_rots_from_two_vecs(e0_unnormalized, e1_unnormalized):
+    """Builds rotation matrices from two unnormalized vectors via Gram-Schmidt."""
+    e0 = multimer_vecs_robust_normalize(e0_unnormalized)
+    c = vecs_dot_vecs(e1_unnormalized, e0)
+    e1 = vecs_sub(e1_unnormalized, vecs_scale(e0, c))
+    e1 = multimer_vecs_robust_normalize(e1)
+    e2 = vecs_cross_vecs(e0, e1)
+
+    rots = (e0[0], e1[0], e2[0],
+            e0[1], e1[1], e2[1],
+            e0[2], e1[2], e2[2])
+    return rots
+
+
+def multimer_rigids_from_3_points(vec_a, vec_b, vec_c):
+    """Create multimer Rigids from 3
points. """ + m = multimer_rots_from_two_vecs( + e0_unnormalized=vecs_sub(vec_c, vec_b), + e1_unnormalized=vecs_sub(vec_a, vec_b)) + rigid = (m, vec_b) + return rigid + + +def multimer_rigids_get_unit_vector(point_a, point_b, point_c): + """multimer_rigids_get_unit_vector.""" + rigid = multimer_rigids_from_3_points(vecs_from_tensor(point_a), + vecs_from_tensor(point_b), + vecs_from_tensor(point_c)) + rot, trans = rigid + rotation = rots_expand_dims(rot, -1) + translation = vecs_expand_dims(trans, -1) + inv_rigid = invert_rigids((rotation, translation)) + rigid_vec = rigids_mul_vecs(inv_rigid, vecs_expand_dims(trans, -2)) + unit_vector = multimer_vecs_robust_normalize(rigid_vec) + return unit_vector + + +def multimer_rigids_compute_dihedral_angle(a, b, c, d): + """multimer_rigids_compute_dihedral_angle.""" + v1 = vecs_sub(a, b) + v2 = vecs_sub(b, c) + v3 = vecs_sub(d, c) + + c1 = vecs_cross_vecs(v1, v2) + c2 = vecs_cross_vecs(v3, v2) + c3 = vecs_cross_vecs(c2, c1) + + v2_mag = multimer_vecs_robust_norm(v2) + return mnp.arctan2(vecs_dot_vecs(c3, v2), v2_mag * vecs_dot_vecs(c1, c2)) + + +class MultimerInvariantPointAttention(nn.Cell): + """Invariant Point attention module.""" + + def __init__(self, num_head, num_scalar_qk, num_scalar_v, num_point_v, num_point_qk, num_channel, pair_dim): + """ + + Args: + pair_dim: pair representation dimension. + """ + + super(MultimerInvariantPointAttention, self).__init__() + + self._dist_epsilon = 1e-8 + self.num_head = num_head + self.num_scalar_qk = num_scalar_qk + self.num_scalar_v = num_scalar_v + self.num_point_v = num_point_v + self.num_point_qk = num_point_qk + self.num_channel = num_channel + self.projection_num = self.num_head * self.num_scalar_v + self.num_head * self.num_point_v * 4 + \ + self.num_head * pair_dim + self.q_scalar = nn.Dense(self.num_channel, self.num_head * self.num_scalar_qk, + weight_init=lecun_init(self.num_channel), has_bias=False) + self.k_scalar = nn.Dense(self.num_channel, self.num_head * self.num_scalar_qk, + weight_init=lecun_init(self.num_channel), has_bias=False) + self.v_scalar = nn.Dense(self.num_channel, self.num_head * self.num_scalar_v, + weight_init=lecun_init(self.num_channel), has_bias=False) + self.q_point_local = nn.Dense(self.num_channel, self.num_head * 3 * self.num_point_qk, + weight_init=lecun_init(self.num_channel)) + self.k_point_local = nn.Dense(self.num_channel, self.num_head * 3 * self.num_point_qk, + weight_init=lecun_init(self.num_channel)) + self.v_point_local = nn.Dense(self.num_channel, self.num_head * 3 * self.num_point_v, + weight_init=lecun_init(self.num_channel)) + self.soft_max = nn.Softmax(axis=-2) + self.trainable_point_weights = Parameter(Tensor(np.ones((12,)), mstype.float32), name="trainable_point_weights") + self.attention_2d = nn.Dense(pair_dim, self.num_head, weight_init=lecun_init(pair_dim)) + self.output_projection = nn.Dense(self.projection_num, self.num_channel, weight_init='zeros') + + self.point_weights = np.sqrt(1.0 / (max(num_point_qk, 1) * 9. / 2)) + self.scalar_weights = np.sqrt(1.0 / (max(num_scalar_qk, 1) * 1.)) + + def construct(self, inputs_1d, inputs_2d, mask, rotation, translation): + """Compute geometry-aware attention. + + Args: + inputs_1d: (N, C) 1D input embedding that is the basis for the + scalar queries. + inputs_2d: (N, M, C') 2D input embedding, used for biases and values. + mask: (N, 1) mask to indicate which elements of inputs_1d participate + in the attention. 
+            rotation: describes the orientation of every element in inputs_1d
+            translation: describes the position of every element in inputs_1d
+
+        Returns:
+            Transformation of the input embedding.
+        """
+        num_residues, _ = inputs_1d.shape
+
+        num_head = self.num_head
+        attn_logits = 0.
+        num_point_qk = self.num_point_qk
+        point_weights = self.point_weights
+
+        trainable_point_weights = mnp.logaddexp(self.trainable_point_weights,
+                                                mnp.zeros_like(self.trainable_point_weights))
+        point_weights = point_weights * trainable_point_weights
+
+        q_point_local = self.q_point_local(inputs_1d)
+        q_point_local = mnp.reshape(q_point_local, (num_residues, num_head, num_point_qk * 3))
+        q_point_local = mnp.split(q_point_local, 3, axis=-1)
+        q_point_local = (ops.Squeeze()(q_point_local[0]), ops.Squeeze()(q_point_local[1]),
+                         ops.Squeeze()(q_point_local[2]))
+        # Project query points into global frame.
+        q_point_global = apply_to_point(rotation, translation, q_point_local, 2)
+        q_point = [q_point_global[0][:, None, :, :], q_point_global[1][:, None, :, :], q_point_global[2][:, None, :, :]]
+
+        k_point_local = self.k_point_local(inputs_1d)
+        k_point_local = mnp.reshape(k_point_local, (num_residues, num_head, num_point_qk * 3))
+        k_point_local = mnp.split(k_point_local, 3, axis=-1)
+        k_point_local = (ops.Squeeze()(k_point_local[0]), ops.Squeeze()(k_point_local[1]),
+                         ops.Squeeze()(k_point_local[2]))
+
+        # Project key points into global frame.
+        k_point_global = apply_to_point(rotation, translation, k_point_local, 2)
+        k_point = [k_point_global[0][None, :, :, :], k_point_global[1][None, :, :, :], k_point_global[2][None, :, :, :]]
+
+        dist2 = multimer_square_euclidean_distance(q_point, k_point, epsilon=0.)
+
+        attn_qk_point = -0.5 * mnp.sum(point_weights[:, None] * dist2, axis=-1)
+        attn_logits += attn_qk_point
+
+        num_scalar_qk = self.num_scalar_qk
+
+        scalar_weights = self.scalar_weights
+        q_scalar = self.q_scalar(inputs_1d)
+        q_scalar = mnp.reshape(q_scalar, [num_residues, num_head, num_scalar_qk])
+
+        k_scalar = self.k_scalar(inputs_1d)
+        k_scalar = mnp.reshape(k_scalar, [num_residues, num_head, num_scalar_qk])
+
+        q_scalar *= scalar_weights
+        q = mnp.swapaxes(q_scalar, -2, -3)
+        k = mnp.swapaxes(k_scalar, -2, -3)
+        attn_qk_scalar = ops.matmul(q, mnp.swapaxes(k, -2, -1))
+        attn_qk_scalar = mnp.swapaxes(attn_qk_scalar, -2, -3)
+        attn_qk_scalar = mnp.swapaxes(attn_qk_scalar, -2, -1)
+        attn_logits += attn_qk_scalar
+
+        attention_2d = self.attention_2d(inputs_2d)
+        attn_logits += attention_2d
+
+        mask_2d = mask * mnp.swapaxes(mask, -1, -2)
+        attn_logits -= 1e5 * (1. - mask_2d[..., None])
+        attn_logits *= mnp.sqrt(1. / 3)
+        attn = self.soft_max(attn_logits)
+
+        num_scalar_v = self.num_scalar_v
+        v_scalar = self.v_scalar(inputs_1d)
+        v_scalar = mnp.reshape(v_scalar, [num_residues, num_head, num_scalar_v])
+
+        attn_tmp = mnp.swapaxes(attn, -1, -2)
+        attn_tmp = mnp.swapaxes(attn_tmp, -2, -3)
+        result_scalar = ops.matmul(attn_tmp, mnp.swapaxes(v_scalar, -2, -3))
+        result_scalar = mnp.swapaxes(result_scalar, -2, -3)
+
+        num_point_v = self.num_point_v
+
+        v_point_local = self.v_point_local(inputs_1d)
+        v_point_local = mnp.reshape(v_point_local, (num_residues, num_head, num_point_v * 3))
+        v_point_local = mnp.split(v_point_local, 3, axis=-1)
+        v_point_local = (ops.Squeeze()(v_point_local[0]), ops.Squeeze()(v_point_local[1]),
+                         ops.Squeeze()(v_point_local[2]))
+        # Project value points into global frame.
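+        # Note (added): as with the queries and keys, the value points are lifted into the
+        # global frame so the attention-weighted sums below are rotation- and
+        # translation-equivariant; the pooled points are mapped back to the local frame
+        # with invert_point before the output projection.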
+ v_point_global = apply_to_point(rotation, translation, v_point_local, 2) + v_point = [v_point_global[0][None], v_point_global[1][None], v_point_global[2][None]] + + result_point_global = [mnp.sum(attn[..., None] * v_point[0], axis=-3), + mnp.sum(attn[..., None] * v_point[1], axis=-3), + mnp.sum(attn[..., None] * v_point[2], axis=-3) + ] + + num_query_residues, _ = inputs_1d.shape + + result_scalar = mnp.reshape(result_scalar, [num_query_residues, -1]) + + output_feature1 = result_scalar + + result_point_global = [mnp.reshape(result_point_global[0], [num_query_residues, -1]), + mnp.reshape(result_point_global[1], [num_query_residues, -1]), + mnp.reshape(result_point_global[2], [num_query_residues, -1])] + result_point_local = invert_point(result_point_global, rotation, translation, 1) + output_feature20 = result_point_local[0] + output_feature21 = result_point_local[1] + output_feature22 = result_point_local[2] + point_norms = multimer_vecs_robust_norm(result_point_local, self._dist_epsilon) + output_feature3 = point_norms + + result_attention_over_2d = ops.matmul(mnp.swapaxes(attn, 1, 2), inputs_2d) + output_feature4 = mnp.reshape(result_attention_over_2d, [num_query_residues, -1]) + final_act = mnp.concatenate([output_feature1, output_feature20, output_feature21, + output_feature22, output_feature3, output_feature4], axis=-1) + final_result = self.output_projection(final_act) + return final_result diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_evoformer.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_evoformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c38431f7a31449637e7b03f26d696d1b5dcaa8a2 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_evoformer.py @@ -0,0 +1,120 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Evoformer""" + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindsponge.pipeline.cell import MSARowAttentionWithPairBias, Transition, OuterProductMean, \ + TriangleAttention, TriangleMultiplication, \ + MSAColumnGlobalAttention, MSAColumnAttention + + +class MultimerEvoformer(nn.Cell): + '''multimerevoformer''' + + def __init__(self, config, msa_act_dim, pair_act_dim, is_extra_msa, batch_size): + super(MultimerEvoformer, self).__init__() + if is_extra_msa: + self.slice_cfg = config.slice.extra_msa_stack + else: + self.slice_cfg = config.slice.msa_stack + self.config = config.evoformer + self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias( + self.config.msa_row_attention_with_pair_bias.num_head, + msa_act_dim, + self.config.msa_row_attention_with_pair_bias.gating, + msa_act_dim, + pair_act_dim, + batch_size, + self.slice_cfg.msa_row_attention_with_pair_bias) + + self.msa_transition = Transition(self.config.msa_transition.num_intermediate_factor, + msa_act_dim, + batch_size, + self.slice_cfg.msa_transition) + + self.outer_product_mean = OuterProductMean(self.config.outer_product_mean.num_outer_channel, + msa_act_dim, + pair_act_dim, + batch_size, + self.slice_cfg.outer_product_mean) + + self.triangle_attention_starting_node = TriangleAttention( + self.config.triangle_attention_starting_node.orientation, + self.config.triangle_attention_starting_node.num_head, + pair_act_dim, + self.config.triangle_attention_starting_node.gating, + pair_act_dim, + batch_size, + self.slice_cfg.triangle_attention_starting_node) + + self.triangle_attention_ending_node = TriangleAttention(self.config.triangle_attention_ending_node.orientation, + self.config.triangle_attention_ending_node.num_head, + pair_act_dim, + self.config.triangle_attention_ending_node.gating, + pair_act_dim, + batch_size, + self.slice_cfg.triangle_attention_ending_node) + + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, + pair_act_dim, + batch_size, + self.slice_cfg.pair_transition) + + self.triangle_multiplication_outgoing = TriangleMultiplication( + self.config.triangle_multiplication_outgoing.num_intermediate_channel, + self.config.triangle_multiplication_outgoing.equation, + layer_norm_dim=pair_act_dim, + batch_size=batch_size) + + self.triangle_multiplication_incoming = TriangleMultiplication( + self.config.triangle_multiplication_incoming.num_intermediate_channel, + self.config.triangle_multiplication_incoming.equation, + layer_norm_dim=pair_act_dim, + batch_size=batch_size) + if is_extra_msa: + self.attn_mod = MSAColumnGlobalAttention(self.config.msa_column_attention.num_head, + self.config.msa_column_attention.gating, + msa_act_dim, + batch_size, + self.slice_cfg.msa_column_global_attention) + else: + self.attn_mod = MSAColumnAttention(self.config.msa_column_attention.num_head, + msa_act_dim, + self.config.msa_column_attention.gating, + msa_act_dim, + batch_size, + self.slice_cfg.msa_column_attention) + + def construct(self, msa_act, pair_act, msa_mask, extra_msa_norm, pair_mask, index=None): + '''construct''' + pair_act = P.Add()(pair_act, self.outer_product_mean(msa_act, msa_mask, extra_msa_norm, index)) + msa_act = P.Add()(msa_act, self.msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act, index)) + msa_act = P.Add()(msa_act, self.attn_mod(msa_act, msa_mask, index)) + msa_act = P.Add()(msa_act, self.msa_transition(msa_act, index)) + pair_act = 
P.Add()(pair_act, self.triangle_multiplication_outgoing(pair_act, pair_mask, index)) + pair_act = P.Add()(pair_act, self.triangle_multiplication_incoming(pair_act, pair_mask, index)) + pair_act = P.Add()(pair_act, self.triangle_attention_starting_node(pair_act, pair_mask, index)) + pair_act = P.Add()(pair_act, self.triangle_attention_ending_node(pair_act, pair_mask, index)) + pair_act = P.Add()(pair_act, self.pair_transition(pair_act, index)) + return msa_act, pair_act diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_head.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..cf3ec50f685f232afd313e7acaf0fd2f441149c3 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_head.py @@ -0,0 +1,55 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""multimer heads"""
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+from mindsponge.pipeline.cell.initializer import lecun_init
+
+
+class PredictedLDDTHead(nn.Cell):
+    """Head to predict the per-residue LDDT to be used as a confidence measure."""
+
+    def __init__(self, config, seq_channel):
+        super().__init__()
+        self.config = config
+        self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5)
+        self.act_0 = nn.Dense(seq_channel, self.config.num_channels,
+                              weight_init=lecun_init(seq_channel, initializer_name='relu')
+                              ).to_float(mstype.float16)
+        self.act_1 = nn.Dense(self.config.num_channels, self.config.num_channels,
+                              weight_init=lecun_init(self.config.num_channels, initializer_name='relu')
+                              ).to_float(mstype.float16)
+        self.logits = nn.Dense(self.config.num_channels, self.config.num_bins, weight_init='zeros'
+                               ).to_float(mstype.float16)
+        self.relu = nn.ReLU()
+
+    def construct(self, rp_structure_module):
+        """Builds the PredictedLDDTHead module."""
+        act = rp_structure_module
+        act = self.input_layer_norm(act.astype(mstype.float32))
+        act = self.act_0(act)
+        act = self.relu(act.astype(mstype.float32))
+        act = self.act_1(act)
+        act = self.relu(act.astype(mstype.float32))
+        logits = self.logits(act)
+        return logits
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_structure.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_structure.py
new file mode 100644
index 0000000000000000000000000000000000000000..f990ffd1f82d5e29a46cfbe202a472ddfac898dc
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_structure.py
@@ -0,0 +1,252 @@
+# Copyright 2023 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""structure module"""
+import numpy as np
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+import mindspore.numpy as mnp
+import mindspore.ops as ops
+from mindspore import Tensor
+from mindspore.ops import functional as F
+from .....common import residue_constants
+from ....cell.initializer import lecun_init
+from .....common.utils import torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, \
+    atom14_to_atom37
+from .....common.geometry import initial_affine, quaternion_to_tensor, pre_compose, vecs_scale, \
+    vecs_to_tensor, vecs_expand_dims, rots_expand_dims
+from .multimer_block import MultimerInvariantPointAttention
+
+
+class MultiRigidSidechain(nn.Cell):
+    """Class to make side chain atoms."""
+
+    def __init__(self, config, single_repr_dim):
+        super().__init__()
+        self.config = config
+        self.input_projection = nn.Dense(single_repr_dim, self.config.num_channel,
+                                         weight_init=lecun_init(single_repr_dim))
+        self.input_projection_1 = nn.Dense(single_repr_dim, self.config.num_channel,
+                                           weight_init=lecun_init(single_repr_dim))
+        self.relu = nn.ReLU()
+        self.resblock1 = nn.Dense(self.config.num_channel, self.config.num_channel,
+                                  weight_init=lecun_init(self.config.num_channel,
+                                                         initializer_name='relu'))
+        self.resblock2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
+        self.resblock1_1 = nn.Dense(self.config.num_channel, self.config.num_channel,
+                                    weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
+        self.resblock2_1 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
+        self.unnormalized_angles = nn.Dense(self.config.num_channel, 14,
+                                            weight_init=lecun_init(self.config.num_channel))
+        self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group)
+        self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions)
+        self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask)
+        self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame)
+        self.l2_normalize = ops.L2Normalize(axis=-1, epsilon=1e-12)
+
+    def construct(self, rotation, translation, act, initial_act, aatype):
+        """Predict side chains using rotation and translation representations.
+
+        Args:
+            rotation: The rotation matrices.
+            translation: The translation vectors.
+            act: updated single-representation activations from the structure module
+            initial_act: initial act representations (input of structure module)
+            aatype: Amino acid type representations
+
+        Returns:
+            angles, positions and new frames
+        """
+
+        act1 = self.input_projection(self.relu(act))
+        init_act1 = self.input_projection_1(self.relu(initial_act))
+        # Sum the activation list (equivalent to concat then Linear).
+        act = act1 + init_act1
+
+        # Mapping with some residual blocks.
+        # resblock1
+        old_act = act
+        act = self.resblock1(self.relu(act))
+        act = self.resblock2(self.relu(act))
+        act += old_act
+        # resblock2
+        old_act = act
+        act = self.resblock1_1(self.relu(act))
+        act = self.resblock2_1(self.relu(act))
+        act += old_act
+
+        # Map activations to torsion angles. Shape: (num_res, 14).
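+        # Note (added): 14 = 7 torsion angles (omega, phi, psi, chi1-chi4) x (sin, cos);
+        # the raw predictions are L2-normalized below so each (sin, cos) pair lies on
+        # the unit circle.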
+ num_res = act.shape[0] + unnormalized_angles = self.unnormalized_angles(self.relu(act)) + + unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2]) + angles = self.l2_normalize(unnormalized_angles) + + backb_to_global = ((rotation[0], rotation[1], rotation[2], + rotation[3], rotation[4], rotation[5], + rotation[6], rotation[7], rotation[8]), + (translation[0], translation[1], translation[2])) + + all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles, + self.restype_rigid_group_default_frame) + + pred_positions = frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global, + self.restype_atom14_to_rigid_group, + self.restype_atom14_rigid_group_positions, + self.restype_atom14_mask) + + atom_pos = pred_positions + frames = all_frames_to_global + res = (angles, unnormalized_angles, atom_pos, frames) + return res + + +class MultimerFoldIteration(nn.Cell): + """A single iteration of the main structure module loop.""" + + def __init__(self, config, pair_dim, single_repr_dim): + super().__init__() + self.config = config + self.drop_out = nn.Dropout(keep_prob=0.9) + self.attention_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5) + self.transition_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5) + self.transition = nn.Dense(self.config.num_channel, config.num_channel, + weight_init=lecun_init(self.config.num_channel, initializer_name='relu')) + self.transition_1 = nn.Dense(self.config.num_channel, self.config.num_channel, + weight_init=lecun_init(self.config.num_channel, initializer_name='relu')) + self.transition_2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros') + self.relu = nn.ReLU() + self.affine_update = nn.Dense(self.config.num_channel, 6, weight_init='zeros') + self.attention_module = MultimerInvariantPointAttention(self.config.num_head, + self.config.num_scalar_qk, + self.config.num_scalar_v, + self.config.num_point_v, + self.config.num_point_qk, + self.config.num_channel, + pair_dim) + self.mu_side_chain = MultiRigidSidechain(self.config.sidechain, single_repr_dim) + + def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype): + """construct""" + attn = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation) + act += attn + act = self.drop_out(act) + act = self.attention_layer_norm(act) + # Transition + input_act = act + act = self.transition(act) + act = self.relu(act) + act = self.transition_1(act) + act = self.relu(act) + act = self.transition_2(act) + + act += input_act + act = self.drop_out(act) + act = self.transition_layer_norm(act) + # This block corresponds to + # Jumper et al. (2021) Alg. 
23 "Backbone update" + # Affine update + affine_update = self.affine_update(act) + quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, affine_update) + translation1 = vecs_scale(translation, 20.0) + rotation1 = rotation + angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames = \ + self.mu_side_chain(rotation1, translation1, act, initial_act, aatype) + affine_output = quaternion_to_tensor(quaternion, translation) + quaternion = F.stop_gradient(quaternion) + rotation = F.stop_gradient(rotation) + res = (act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \ + atom_pos, frames) + return res + + +class MultimerStructureModule(nn.Cell): + """StructureModule as a network head.""" + + def __init__(self, config, single_repr_dim, pair_dim): + super(MultimerStructureModule, self).__init__() + self.config = config.model.structure_module + self.seq_length = config.seq_length + self.fold_iteration = MultimerFoldIteration(self.config, pair_dim, single_repr_dim) + self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5) + self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5) + self.num_layer = self.config.num_layer + self.indice0 = Tensor( + np.arange(self.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32")) + self.traj_w = Tensor(np.array([1.] * 4 + [self.config.position_scale] * 3), mstype.float32) + + def construct(self, single, pair, seq_mask, aatype, residx_atom37_to_atom14=None, atom37_atom_exists=None): + """construct""" + sequence_mask = seq_mask[:, None] + act = self.single_layer_norm(single) + initial_act = act + act = self.initial_projection(act) + quaternion, rotation, translation = initial_affine(self.seq_length) + act_2d = self.pair_layer_norm(pair) + + # folder iteration + atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, act_iter = \ + self.iteration_operation(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype) + atom14_pred_positions = vecs_to_tensor(atom_pos)[-1] + sidechain_atom_pos = atom_pos + + atom37_pred_positions = atom14_to_atom37(atom14_pred_positions, + residx_atom37_to_atom14, + atom37_atom_exists, + self.indice0) + structure_traj = affine_output_new * self.traj_w + final_affines = affine_output_new[-1] + final_atom_positions = atom37_pred_positions + final_atom_mask = atom37_atom_exists + rp_structure_module = act_iter + res = (final_atom_positions, final_atom_mask, rp_structure_module, atom14_pred_positions, final_affines, \ + angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, sidechain_atom_pos, structure_traj) + return res + + def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, + aatype): + """iteration_operation""" + affine_init = () + angles_sin_cos_init = () + um_angles_sin_cos_init = () + atom_pos_batch = () + frames_batch = () + + for _ in range(self.num_layer): + act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \ + atom_pos, frames = \ + self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype) + affine_init = affine_init + (affine_output[None, ...],) + angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],) + um_angles_sin_cos_init = um_angles_sin_cos_init + 
(unnormalized_angles_sin_cos[None, ...],) + atom_pos_batch += (mnp.concatenate(vecs_expand_dims(atom_pos, 0), axis=0)[:, None, ...],) + frames_batch += (mnp.concatenate(rots_expand_dims(frames[0], 0) + + vecs_expand_dims(frames[1], 0), axis=0)[:, None, ...],) + affine_output_new = mnp.concatenate(affine_init, axis=0) + angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0) + um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0) + frames_new = mnp.concatenate(frames_batch, axis=1) + atom_pos_new = mnp.concatenate(atom_pos_batch, axis=1) + res = (atom_pos_new, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames_new, act) + return res diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_template_embedding.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_template_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..316c343d499cef9a69e73100b634ef3dba4f3f2b --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/module/multimer_template_embedding.py @@ -0,0 +1,221 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +'''TEMPLATE''' +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore.ops import operations as P +from mindsponge.pipeline.cell.initializer import lecun_init +from mindsponge.common.utils import dgram_from_positions, pseudo_beta_fn +from mindsponge.common.residue_constants import atom_order +from mindsponge.pipeline.cell import TriangleAttention, Transition, TriangleMultiplication +from .multimer_block import multimer_rigids_get_unit_vector + + +class MultimerTemplatePairStack(nn.Cell): + '''multimer template pair stack''' + + def __init__(self, config): + super(MultimerTemplatePairStack, self).__init__() + self.config = config.template.template_pair_stack + self.num_block = self.config.num_block + batch_size = 0 + self.slice = config.slice.template_pair_stack + start_node_cfg = self.config.triangle_attention_starting_node + self.triangle_attention_starting_node = TriangleAttention(start_node_cfg.orientation, + start_node_cfg.num_head, + start_node_cfg.key_dim, + start_node_cfg.gating, + 64, + batch_size, + self.slice.triangle_attention_starting_node) + end_node_cfg = self.config.triangle_attention_ending_node + self.triangle_attention_ending_node = TriangleAttention(end_node_cfg.orientation, + end_node_cfg.num_head, + end_node_cfg.key_dim, + end_node_cfg.gating, + 64, + batch_size, + self.slice.triangle_attention_ending_node) + # Hard Code + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, + 64, + batch_size, + self.slice.pair_transition) + + mul_outgoing_cfg = self.config.triangle_multiplication_outgoing + self.triangle_multiplication_outgoing = TriangleMultiplication(mul_outgoing_cfg.num_intermediate_channel, + mul_outgoing_cfg.equation, + layer_norm_dim=64, + batch_size=batch_size) + mul_incoming_cfg = self.config.triangle_multiplication_incoming + self.triangle_multiplication_incoming = TriangleMultiplication(mul_incoming_cfg.num_intermediate_channel, + mul_incoming_cfg.equation, + layer_norm_dim=64, + batch_size=batch_size) + + def construct(self, pair_act, pair_mask, index=None): + if not self.num_block: + return pair_act + + pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index) + pair_act = pair_act + self.pair_transition(pair_act, index) + return pair_act + + +class MultimerSingleTemplateEmbedding(nn.Cell): + '''multimer single template embedding''' + + def __init__(self, config, mixed_precision): + super(MultimerSingleTemplateEmbedding, self).__init__() + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_bins = self.config.dgram_features.num_bins + self.min_bin = self.config.dgram_features.min_bin + self.max_bin = self.config.dgram_features.max_bin + + self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + self.template_dgram_temp_dense = nn.Dense(39, self.num_channels, + weight_init=lecun_init(39, initializer_name='relu')) + self.template_mask_2d_temp_dense = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.aatype_temp_0 = nn.Dense(22, 
self.num_channels, + weight_init=lecun_init(22, initializer_name='relu')) + self.aatype_temp_1 = nn.Dense(22, self.num_channels, + weight_init=lecun_init(22, initializer_name='relu')) + self.unit_vector_0 = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.unit_vector_1 = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.unit_vector_2 = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.backbone_mask_2d_dense = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.embedding2d = nn.Dense(128, self.num_channels, + weight_init=lecun_init(128, initializer_name='relu')) + template_layers = nn.CellList() + for _ in range(self.config.template_pair_stack.num_block): + template_pair_stack_block = MultimerTemplatePairStack(config) + template_layers.append(template_pair_stack_block) + self.template_pair_stack = template_layers + + self.one_hot = nn.OneHot(depth=22, axis=-1) + self.n, self.ca, self.c = [atom_order[a] for a in ('N', 'CA', 'C')] + + layer_norm_dim = 64 + self.query_embedding_norm = nn.LayerNorm([128,], epsilon=1e-5) + self.output_layer_norm = nn.LayerNorm([layer_norm_dim,], epsilon=1e-5) + self.num_block = self.config.template_pair_stack.num_block + self.batch_block = 4 + + def construct(self, pair_activations, template_aatype, + template_all_atom_positions, template_all_atom_mask, + padding_mask_2d, multichain_mask_2d): + '''construct''' + num_templates = template_aatype.shape[0] + template_positions, template_pseudo_beta_mask = pseudo_beta_fn(template_aatype, + template_all_atom_positions, + template_all_atom_mask) + template_mask_2d_temp = P.ExpandDims()(template_pseudo_beta_mask, -1) * \ + P.ExpandDims()(template_pseudo_beta_mask, 1) + + template_mask_2d_temp *= multichain_mask_2d + template_dgram_temp = dgram_from_positions(template_positions, self.num_bins, self.min_bin, + self.max_bin, self._type) + template_dgram_temp *= template_mask_2d_temp[..., None] + act_tmp = self.template_dgram_temp_dense(template_dgram_temp) + act_tmp += self.template_mask_2d_temp_dense((P.ExpandDims()(template_mask_2d_temp, -1))) + aatype_temp = self.one_hot(template_aatype) + aatype_temp = P.Cast()(aatype_temp, self._type) + act_tmp += self.aatype_temp_0((P.ExpandDims()(aatype_temp, 1))) + act_tmp += self.aatype_temp_1((P.ExpandDims()(aatype_temp, 2))) + backbone_mask = (template_all_atom_mask[:, :, self.n] * + template_all_atom_mask[:, :, self.ca] * + template_all_atom_mask[:, :, self.c]) + unit_vector = multimer_rigids_get_unit_vector(template_all_atom_positions[:, :, self.n], + template_all_atom_positions[:, :, self.ca], + template_all_atom_positions[:, :, self.c]) + + backbone_mask_2d = (P.ExpandDims()(backbone_mask, -1)) * (P.ExpandDims()(backbone_mask, 1)) + backbone_mask_2d *= multichain_mask_2d + unit_vector = (P.ExpandDims()(backbone_mask_2d * unit_vector[0], -1), + P.ExpandDims()(backbone_mask_2d * unit_vector[1], -1), + P.ExpandDims()(backbone_mask_2d * unit_vector[2], -1)) + pair_activations = self.query_embedding_norm(pair_activations) + num_res, _, query_num_channels = pair_activations.shape + pair_init = mnp.zeros((num_templates, num_res, num_res, query_num_channels), dtype=self._type) + pair_activations = pair_init + pair_activations + act_tmp += self.unit_vector_0(unit_vector[0]) + act_tmp += self.unit_vector_1(unit_vector[1]) + act_tmp += self.unit_vector_2(unit_vector[2]) + act_tmp += 
self.backbone_mask_2d_dense(P.ExpandDims()(backbone_mask_2d, -1)) + act_tmp += self.embedding2d(pair_activations) + + act_tmp = P.Split(0, self.batch_block)(act_tmp) + scan_init = mnp.zeros((num_res, num_res, self.num_channels), dtype=self._type) + act = () + for i in range(self.batch_block): + act = act + (P.Squeeze()(act_tmp[i]),) + + for i in range(self.batch_block): + act_batch = act[i] + for j in range(self.num_block): + act_batch = self.template_pair_stack[j](act_batch, padding_mask_2d) + scan_init += self.output_layer_norm(act_batch) + return scan_init + + +class MultimerTemplateEmbedding(nn.Cell): + '''multimer template embedding''' + + def __init__(self, config, mixed_precision=True): + super(MultimerTemplateEmbedding, self).__init__() + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + self.template_embedder = MultimerSingleTemplateEmbedding(config, mixed_precision) + self.relu = nn.ReLU() + self.output_linear = nn.Dense(self.num_channels, config.pair_channel, + weight_init=lecun_init(self.num_channels, initializer_name='relu')) + + def construct(self, pair_activations, template_aatype, template_all_atom_mask, template_all_atom_positions, + padding_mask_2d, multichain_mask_2d): + '''construct''' + num_templates = template_aatype.shape[0] + embedding = self.template_embedder(pair_activations, template_aatype, + template_all_atom_positions, + template_all_atom_mask, + padding_mask_2d, + multichain_mask_2d) + embedding = embedding / num_templates + embedding = self.relu(embedding) + return self.output_linear(embedding) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer.py new file mode 100644 index 0000000000000000000000000000000000000000..df45e503d61d7827b9aa7740b5cff6af75a51ed8 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer.py @@ -0,0 +1,118 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""multimer""" +import time +from mindspore import jit, context, nn +from mindspore.common import mutable +from mindspore import Tensor +from .nn_arch import MultimerArch, compute_confidence +from ..model import Model + + +class Multimer(Model): + """Multimer""" + name = "Multimer" + feature_list = ['aatype', 'residue_index', 'template_aatype', 'template_all_atom_mask', + 'template_all_atom_positions', 'asym_id', 'sym_id', 'entity_id', 'seq_mask', 'msa_mask', + 'target_feat', 'msa_feat', 'extra_msa', 'extra_deletion_matrix', 'extra_msa_mask', + 'residx_atom37_to_atom14', 'atom37_atom_exists', + 'prev_pos', 'prev_msa_first_row', 'prev_pair'] + + def __init__(self, config): + context.set_context(memory_optimize_level="O1", max_call_depth=6000) + if context.get_context("device_target") == "GPU": + self.mixed_precision = False + context.set_context(graph_kernel_flags="--disable_expand_ops=Softmax --disable_cluster_ops=ReduceSum " + "--composite_op_limit_size=50", enable_graph_kernel=True) + else: + self.mixed_precision = True + + self.config = config + self.use_jit = self.config.use_jit + self.white_list = (nn.Softmax, nn.LayerNorm) + self.checkpoint_url = \ + 'https://download.mindspore.cn/mindscience/mindsponge/Multimer/checkpoint/Multimer_Model_1.ckpt' + self.checkpoint_path = "./Multimer_Model_1.ckpt" + self.network = MultimerArch(self.config, self.mixed_precision) + super().__init__(self.checkpoint_url, self.network, self.name, self.white_list) + + def forward(self, data): + if self.use_jit: + prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits = self._jit_forward(data) + else: + prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits = self._pynative_forward(data) + return prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits + + # pylint: disable=arguments-differ + def predict(self, inputs, num_recycle=1): + num_residues = inputs["num_residues"] + recycle_feature_name = self.feature_list[:-3] + prev_pos = Tensor(inputs['prev_pos']) + prev_msa_first_row = Tensor(inputs['prev_msa_first_row']) + prev_pair = Tensor(inputs['prev_pair']) + for recycle in range(num_recycle): + data = {} + for key in recycle_feature_name: + data[key] = Tensor(inputs[key][recycle]) + data['prev_pos'] = prev_pos + data['prev_msa_first_row'] = prev_msa_first_row + data['prev_pair'] = prev_pair + data = mutable(data) + t1 = time.time() + prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits = self.forward(data) + t2 = time.time() + print(round(t2 - t1, 2)) + final_atom_positions = prev_pos.asnumpy()[:num_residues] + final_atom_mask = data['atom37_atom_exists'].asnumpy()[:num_residues] + predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:num_residues] + confidence, plddt = compute_confidence(predicted_lddt_logits, return_lddt=True) + b_factors = plddt[:, None] * final_atom_mask + return final_atom_positions, final_atom_mask, confidence, b_factors + + def loss(self, data): + pass + + def grad_operations(self, gradient): + pass + + @jit + def backward(self, data): + pass + + def train_step(self): + pass + + @jit + def _jit_forward(self, data): + feat = [] + for key in self.feature_list: + feat.append(data[key]) + prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits = self.network(*feat) + return prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits + + def _pynative_forward(self, data): + feat = [] + for key in self.feature_list: + feat.append(data[key]) + prev_pos, 
prev_msa_first_row, prev_pair, predicted_lddt_logits = self.network(*feat) + return prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_configuration.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..c126bd976071098be160bbc410fd0a76b6749dc8 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_configuration.py @@ -0,0 +1,32 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""multimer configuration""" +multimer_configuration = { + "predict_256": "https://download.mindspore.cn/mindscience/mindsponge/Multimer/config/predict_256.yaml", + "predict_512": "https://download.mindspore.cn/mindscience/mindsponge/Multimer/config/predict_512.yaml", + "predict_768": "https://download.mindspore.cn/mindscience/mindsponge/Multimer/config/predict_768.yaml", + "predict_1024": "https://download.mindspore.cn/mindscience/mindsponge/Multimer/config/predict_1024.yaml", + "predict_1280": "https://download.mindspore.cn/mindscience/mindsponge/Multimer/config/predict_1280.yaml", + "predict_1536": "https://download.mindspore.cn/mindscience/mindsponge/Multimer/config/predict_1536.yaml", + "predict_1792": "https://download.mindspore.cn/mindscience/mindsponge/Multimer/config/predict_1792.yaml" +} diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_data.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_data.py new file mode 100644 index 0000000000000000000000000000000000000000..ce8f922b0df9f37ed77a6a008105b3a22c103c32 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_data.py @@ -0,0 +1,341 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""multimer data""" +import numpy as np +from ....common import residue_constants +from ...dataset import curry1 +from ....data.data_transform import make_atom14_masks +from .multimer_feature import NUM_RES, NUM_MSA_SEQ, NUM_EXTRA_SEQ, NUM_TEMPLATES + + +@curry1 +def dict_filter_key(feature=None, feature_list=None): + feature = {k: v for k, v in feature.items() if k in feature_list} + return feature + + +@curry1 +def dict_replace_key(feature=None, replaced_key=None): + assert len(replaced_key) == 2 + origin_key, new_key = replaced_key + if origin_key in feature: + feature[new_key] = feature.pop(origin_key) + return feature + + +@curry1 +def dict_cast(feature=None, cast_type=None, filtered_list=None): + assert len(cast_type) == 2 + origin_type = cast_type[0] + new_type = cast_type[1] + for k, v in feature.items(): + if k not in filtered_list: + if v.dtype == origin_type: + feature[k] = v.astype(new_type) + return feature + + +@curry1 +def dict_suqeeze(feature=None, filter_list=None, axis=None): + for k in filter_list: + if k in feature: + feat_dim = feature[k].shape[axis] + if isinstance(feat_dim, int) and feat_dim == 1: + feature[k] = np.squeeze(feature[k], axis=axis) + return feature + + +@curry1 +def dict_take(feature=None, filter_list=None, axis=None): + for k in filter_list: + if k in feature: + feature[k] = feature[k][axis] + return feature + + +@curry1 +def dict_del_key(feature=None, filter_list=None): + for k in filter_list: + if k in feature: + del feature[k] + return feature + + +@curry1 +def one_hot_convert(feature=None, key=None, axis=None): + if key in feature: + feature[key] = np.argmax(feature[key], axis=axis) + return feature + + +@curry1 +def correct_restypes(feature=None, key=None): + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = np.array(new_order_list, dtype=feature[key].dtype) + feature[key] = new_order[feature[key]] + return feature + + +@curry1 +def make_msa_profile(feature=None, axis=None, drop_mask_channel=False, eps=1e-10): + """Make_msa_profile.""" + mask = feature['msa_mask'][:, :, None] + value = np.eye(22)[feature['msa']] + feature['target_feat'] = np.eye(21)[feature['aatype']] + if drop_mask_channel: + mask = mask[..., 0] + mask_shape = mask.shape + value_shape = value.shape + broadcast_factor = 1. + value_size = value_shape[axis] + mask_size = mask_shape[axis] + if mask_size == 1: + broadcast_factor *= value_size + feature['msa_profile'] = np.sum(mask * value, axis=axis) / (np.sum(mask, axis=axis) * broadcast_factor + eps) + return feature + + +@curry1 +def sample_msa(feature=None, msa_feature_list=None, max_seq=None, seed=None): + """Sample MSA randomly.""" + if seed is not None: + np.random.seed(seed) + + logits = (np.clip(np.sum(feature['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6 + if 'cluster_bias_mask' not in feature: + cluster_bias_mask = np.pad( + np.zeros(feature['msa'].shape[0] - 1), (1, 0), constant_values=1.) 
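+        # The default bias pins row 0 (the query sequence) so the Gumbel
+        # perturbation applied below can never rank it out of the sampled MSA.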
+    else:
+        cluster_bias_mask = feature['cluster_bias_mask']
+    logits += cluster_bias_mask * 1e6
+    # Gumbel top-k trick: adding Gumbel noise to the logits and sorting samples
+    # max_seq sequences without replacement.
+    z = np.random.gumbel(loc=0.0, scale=1.0, size=logits.shape)
+    index_order = np.argsort((logits + z), axis=-1, kind='quicksort', order=None)
+    sel_idx = index_order[:max_seq]
+    extra_idx = index_order[max_seq:]
+    for k in msa_feature_list:
+        if k in feature:
+            feature['extra_' + k] = feature[k][extra_idx]
+            feature[k] = feature[k][sel_idx]
+
+    if seed is not None:
+        np.random.seed()
+    return feature
+
+
+@curry1
+def make_masked_msa(feature=None, config=None, epsilon=1e-6, seed=None):
+    """Create data for BERT on raw MSA."""
+    if seed is not None:
+        np.random.seed(seed)
+
+    random_aa = np.array([0.05] * 20 + [0., 0.], dtype=np.float32)
+    # Replacement distribution: a mix of uniform, profile and same-residue
+    # probabilities; the extra channel padded on below is the mask token,
+    # which receives the leftover probability mass.
+    categorical_probs = (
+        config.uniform_prob * random_aa +
+        config.profile_prob * feature['msa_profile'] +
+        config.same_prob * np.eye(22)[feature['msa']])
+    pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
+    pad_shapes[-1][1] = 1
+    mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
+    categorical_probs = np.pad(categorical_probs, pad_shapes, constant_values=mask_prob)
+    sh = feature['msa'].shape
+    mask_position = (np.random.uniform(0., 1., sh) < config.replace_fraction).astype(np.float32)
+    mask_position *= feature['msa_mask']
+    logits = np.log(categorical_probs + epsilon)
+    z = np.random.gumbel(loc=0.0, scale=1.0, size=logits.shape)
+    bert_msa = np.eye(logits.shape[-1], dtype=logits.dtype)[np.argmax(logits + z, axis=-1)]
+    bert_msa = (np.where(mask_position,
+                         np.argmax(bert_msa, axis=-1), feature['msa']))
+    bert_msa *= (feature['msa_mask'].astype(np.int64))
+    if 'bert_mask' in feature:
+        feature['bert_mask'] *= mask_position.astype(np.float32)
+    else:
+        feature['bert_mask'] = mask_position.astype(np.float32)
+    feature['true_msa'] = feature['msa']
+    feature['msa'] = bert_msa
+
+    if seed is not None:
+        np.random.seed()
+    return feature
+
+
+def softmax(x, axis):
+    """Softmax func"""
+    x -= np.max(x, axis=axis, keepdims=True)
+    x = np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True)
+    return x
+
+
+def nearest_neighbor_clusters(feature, gap_agreement_weight=0., seed=None):
+    """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
+    if seed is not None:
+        np.random.seed(seed)
+
+    weights = np.array(
+        [1.] * 21 + [gap_agreement_weight] + [0.], dtype=np.float32)
+    msa_mask = feature['msa_mask']
+    msa_one_hot = np.eye(23)[feature['msa']]
+    extra_mask = feature['extra_msa_mask']
+    extra_one_hot = np.eye(23)[feature['extra_msa']]
+    msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
+    extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot
+    agreement = np.einsum('mrc, nrc->nm', extra_one_hot_masked,
+                          weights * msa_one_hot_masked)
+    cluster_assignment = softmax(1e3 * agreement, axis=0)
+    cluster_assignment *= np.einsum('mr, nr->mn', msa_mask, extra_mask)
+    cluster_count = np.sum(cluster_assignment, axis=-1)
+    cluster_count += 1.
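+    # cluster_assignment has shape [num_msa, num_extra]; at temperature 1e3
+    # the softmax above is effectively a hard argmax, so each extra sequence
+    # attaches to its most similar sampled-MSA row. The einsums below average
+    # the members' one-hot and deletion features into each cluster, with the
+    # +1 in cluster_count accounting for the cluster centre itself.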
+ msa_sum = np.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked) + msa_sum += msa_one_hot_masked + feature['cluster_profile'] = msa_sum / cluster_count[:, None, None] + extra_deletion_matrix = feature['extra_deletion_matrix'] + deletion_matrix = feature['deletion_matrix'] + del_sum = np.einsum('nm, mc->nc', cluster_assignment, + extra_mask * extra_deletion_matrix) + del_sum += deletion_matrix + feature['cluster_deletion_mean'] = del_sum / cluster_count[:, None] + + if seed is not None: + np.random.seed() + return feature + + +def create_msa_feat(feature): + """Create and concatenate MSA features.""" + msa_1hot = np.eye(23)[feature['msa']] + deletion_matrix = feature['deletion_matrix'] + has_deletion = np.clip(deletion_matrix, 0., 1.)[..., None] + deletion_value = (np.arctan(deletion_matrix / 3.) * (2. / np.pi))[..., None] + deletion_mean_value = (np.arctan(feature['cluster_deletion_mean'] / 3.) * + (2. / np.pi))[..., None] + msa_feat = [ + msa_1hot, + has_deletion, + deletion_value, + feature['cluster_profile'], + deletion_mean_value + ] + feature['msa_feat'] = np.concatenate(msa_feat, axis=-1) + return feature + + +def make_atom14_mask(feature): + _, _, feature['residx_atom37_to_atom14'], feature['atom37_atom_exists'] = \ + make_atom14_masks(feature['aatype']) + return feature + + +MS_MIN32 = -2147483648 +MS_MAX32 = 2147483647 + + +def make_random_seed(size, seed_maker_t, low=MS_MIN32, high=MS_MAX32, random_recycle=False): + if random_recycle: + r = np.random.RandomState(seed_maker_t) + return r.uniform(size=size, low=low, high=high) + np.random.seed(seed_maker_t) + return np.random.uniform(size=size, low=low, high=high) + + +@curry1 +def random_crop_to_size(feature=None, feature_list=None, crop_size=None, max_templates=None, max_msa_clusters=None, + max_extra_msa=None, seed=None, random_recycle=None): + """Crop randomly to `crop_size`, or keep as is if shorter than that.""" + seq_length = feature['seq_length'] + seq_length_int = int(seq_length) + num_templates = np.array(0, np.int32) + num_res_crop_size = np.minimum(seq_length, crop_size) + num_res_crop_size_int = int(num_res_crop_size) + + # Ensures that the cropping of residues and templates happens in the same way + # across ensembling iterations. + # Do not use for randomness that should vary in ensembling. 
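+    # Worked example with hypothetical sizes: for seq_length = 700 and
+    # crop_size = 512, num_res_crop_start is drawn once from [0, 189), every
+    # NUM_RES-shaped feature is sliced to 512 residues, and the padding pass
+    # at the end of this function then pads everything up to the fixed schema
+    # sizes so downstream graph shapes stay static.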
+ templates_crop_start = 0 + num_templates_crop_size = np.minimum(num_templates - templates_crop_start, max_templates) + num_templates_crop_size_int = int(num_templates_crop_size) + + num_res_crop_start = int(make_random_seed(size=(), seed_maker_t=seed, low=0, + high=seq_length_int - num_res_crop_size_int + 1, + random_recycle=random_recycle)) + + for k, v in feature.items(): + if k not in feature_list or ('template' not in k and NUM_RES not in feature_list.get(k)): + continue + + crop_sizes = [] + crop_starts = [] + for i, (dim_size, dim) in enumerate(zip(feature_list.get(k), v.shape)): + is_num_res = (dim_size == NUM_RES) + if i == 0 and k.startswith('template'): + crop_size_ = num_templates_crop_size_int + crop_start = templates_crop_start + else: + crop_start = num_res_crop_start if is_num_res else 0 + crop_size_ = (num_res_crop_size_int if is_num_res else (-1 if dim is None else dim)) + crop_sizes.append(crop_size_) + crop_starts.append(crop_start) + if len(v.shape) == 1: + feature[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0]] + elif len(v.shape) == 2: + feature[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], + crop_starts[1]:crop_starts[1] + crop_sizes[1]] + elif len(v.shape) == 3: + feature[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], + crop_starts[1]:crop_starts[1] + crop_sizes[1], + crop_starts[2]:crop_starts[2] + crop_sizes[2]] + else: + feature[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], + crop_starts[1]:crop_starts[1] + crop_sizes[1], + crop_starts[2]:crop_starts[2] + crop_sizes[2], + crop_starts[3]:crop_starts[3] + crop_sizes[3]] + + feature["num_residues"] = feature["seq_length"] + feature["seq_length"] = num_res_crop_size + pad_size_map = { + NUM_RES: crop_size, + NUM_MSA_SEQ: max_msa_clusters, + NUM_EXTRA_SEQ: max_extra_msa, + NUM_TEMPLATES: max_templates, + } + + for k, v in feature.items(): + if k not in feature_list or k == "num_residues": + continue + shape = list(v.shape) + schema = feature_list.get(k) + assert len(shape) == len( + schema), f'Rank mismatch between shape and shape schema for {k}: {shape} vs {schema}' + + pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)] + padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] + if padding: + feature[k] = np.pad(v, padding) + feature[k].reshape(pad_size) + + return feature + + +def prev_initial(feature): + feature['prev_pos'] = np.zeros([feature['aatype'].shape[1], 37, 3]) + feature['prev_msa_first_row'] = np.zeros([feature['aatype'].shape[1], 256]) + feature['prev_pair'] = np.zeros([feature['aatype'].shape[1], feature['aatype'].shape[1], 128]) + return feature diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_dataset.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..99d3c3bc8e1a699f337e4d3bea415f71c4ad49ed --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_dataset.py @@ -0,0 +1,115 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. 
+# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""multimer dataset""" +import numpy as np +from mindspore import context +from .multimer_data import make_msa_profile, sample_msa, make_masked_msa, nearest_neighbor_clusters, \ + create_msa_feat, random_crop_to_size, \ + dict_cast, dict_filter_key, prev_initial, make_atom14_mask +from .multimer_feature import _inference_feature, _msa_feature_names +from ...dataset import data_process_run, DataSet + + +class MultimerDataSet(DataSet): + """MultimerDataSet""" + def __init__(self, config, seed=0): + self.config = config + self.in_memory = False + self.phase = None + self.feature_list = None + self.feature_names = _inference_feature + self.multimer_inputs() + + self.data_process = [ + make_msa_profile(axis=0), + sample_msa(msa_feature_list=_msa_feature_names, max_seq=self.config.data.num_msa, seed=seed), + make_masked_msa(config=self.config.data.masked_msa, seed=seed), + nearest_neighbor_clusters, + create_msa_feat, + make_atom14_mask, + random_crop_to_size(feature_list=self.feature_names, crop_size=self.config.seq_length, + max_templates=self.config.data.max_templates, + max_msa_clusters=self.config.max_msa_clusters, + max_extra_msa=self.config.max_extra_msa, + seed=seed, random_recycle=self.config.data.random_recycle), + ] + + self.tail_fns = [] + if context.get_context("device_target") == "GPU": + self.mixed_precision = False + else: + self.mixed_precision = True + + if self.mixed_precision: + data_cast_fns = [dict_cast([np.float64, np.float16], []), + dict_cast([np.float32, np.float16], []), + dict_cast([np.int64, np.int32], [])] + else: + data_cast_fns = [dict_cast([np.float64, np.float32], []), dict_cast([np.int64, np.int32], [])] + + self.tail_fns.extend([dict_filter_key(feature_list=self.feature_names), + prev_initial]) + self.tail_fns.extend(data_cast_fns) + super().__init__() + + def __getitem__(self, idx): + pass + + def __len__(self): + pass + + def multimer_inputs(self): + feature_list = ['aatype', 'residue_index', 'template_aatype', 'template_all_atom_mask', + 'template_all_atom_positions', 'asym_id', 'sym_id', 'entity_id', 'seq_mask', 'msa_mask', + 'target_feat', 'msa_feat', 'extra_msa', 'extra_deletion_matrix', 'extra_msa_mask', + 'residx_atom37_to_atom14', 'atom37_atom_exists', + 'prev_pos', 'prev_msa_first_row', 'prev_pair'] + self.feature_list = feature_list + + def process(self, data, label=None): + """process""" + res = {} + for _ in range(4): + features = data_process_run(data.copy(), self.data_process) + if res == {}: + res = {x: () for x in features.keys()} + for key in features.keys(): + if key == "num_residues": + res[key] = features[key] + else: + res[key] += (features[key][None],) + for key in res.keys(): + if key != 'num_residues': + res[key] = np.concatenate(res[key], axis=0) + features = res + features = data_process_run(features, 
self.tail_fns) + return features + + def download(self, path=None): + pass + + def data_parse(self, input_data, idx): + pass + + def create_iterator(self, num_epochs): + pass diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_feature.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..251b3e22ade987e76c06d610ce1c5f3b984b5cc6 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/multimer_feature.py @@ -0,0 +1,49 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""multimer feature""" +NUM_RES = 'num residues placeholder' +NUM_MSA_SEQ = 'msa placeholder' +NUM_EXTRA_SEQ = 'extra msa placeholder' +NUM_TEMPLATES = 'num templates placeholder' +_msa_feature_names = ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask'] + +_inference_feature = { + 'aatype': [NUM_RES], + 'residue_index': [NUM_RES], + 'template_aatype': [NUM_TEMPLATES, NUM_RES], + 'template_all_atom_mask': [NUM_TEMPLATES, NUM_RES, None], + 'template_all_atom_positions': [NUM_TEMPLATES, NUM_RES, None, None], + 'asym_id': [NUM_RES], + 'sym_id': [NUM_RES], + 'entity_id': [NUM_RES], + 'seq_mask': [NUM_RES], + 'msa_mask': [NUM_MSA_SEQ, NUM_RES], + 'target_feat': [NUM_RES, None], + 'msa_feat': [NUM_MSA_SEQ, NUM_RES, None], + 'extra_msa': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_deletion_matrix': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES], + 'residx_atom37_to_atom14': [NUM_RES, None], + 'atom37_atom_exists': [NUM_RES, None], + 'num_residues': [None] +} diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/nn_arch.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/nn_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..6d1e137b2d681fa6a8f810c25d5764e1d29498f4 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/multimer/nn_arch.py @@ -0,0 +1,303 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""multimer model""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore import Parameter +from scipy.special import softmax +from ....common import residue_constants +from ....common.utils import dgram_from_positions, pseudo_beta_fn +from ....data.data_transform import get_chi_atom_pos_indices +from ...cell.initializer import lecun_init +from .module.multimer_block import compute_chi_angles +from .module.multimer_template_embedding import MultimerTemplateEmbedding +from .module.multimer_evoformer import MultimerEvoformer +from .module.multimer_structure import MultimerStructureModule +from .module.multimer_head import PredictedLDDTHead + + +def caculate_constant_array(seq_length): + '''constant array''' + chi_atom_indices = np.array(get_chi_atom_pos_indices()).astype(np.int32) + chi_angles_mask = list(residue_constants.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = np.array(chi_angles_mask).astype(np.float32) + mirror_psi_mask = np.float32(np.asarray([1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]) + chi_pi_periodic = np.float32(np.array(residue_constants.chi_pi_periodic)) + + indices0 = np.arange(4).reshape((-1, 1, 1, 1, 1)).astype("int32") # 4 batch + indices0 = indices0.repeat(seq_length, axis=1) # seq_length sequence length + indices0 = indices0.repeat(4, axis=2) # 4 chis + indices0 = indices0.repeat(4, axis=3) # 4 atoms + + indices1 = np.arange(seq_length).reshape((1, -1, 1, 1, 1)).astype("int32") + indices1 = indices1.repeat(4, axis=0) + indices1 = indices1.repeat(4, axis=2) + indices1 = indices1.repeat(4, axis=3) + + constant_array = [chi_atom_indices, chi_angles_mask, mirror_psi_mask, chi_pi_periodic, indices0, indices1] + constant_array = [Tensor(val) for val in constant_array] + return constant_array + + +def compute_confidence(predicted_lddt_logits, return_lddt=False): + """compute confidence""" + + num_bins = predicted_lddt_logits.shape[-1] + bin_width = 1 / num_bins + start_n = bin_width / 2 + plddt = compute_plddt(predicted_lddt_logits, start_n, bin_width) + confidence = np.mean(plddt) + if return_lddt: + return confidence, plddt + + return confidence + + +def compute_plddt(logits, start_n, bin_width): + """Computes per-residue pLDDT from logits. + + Args: + logits: [num_res, num_bins] output from the PredictedLDDTHead. + + Returns: + plddt: [num_res] per-residue pLDDT. 
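+
+    Example (assuming num_bins = 50): bin_width = 0.02 and bin_centers run
+    0.01, 0.03, ..., 0.99, so probability mass concentrated in the top bin
+    yields a pLDDT of roughly 99.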
+ """ + bin_centers = np.arange(start=start_n, stop=1.0, step=bin_width) + probs = softmax(logits, axis=-1) + predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1) + return predicted_lddt_ca * 100 + + +class MultimerArch(nn.Cell): + """MultimerArch""" + + def __init__(self, config, mixed_precision): + super(MultimerArch, self).__init__() + + self.cfg = config + + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.recycle_pos = self.cfg.model.recycle_pos + self.recycle_features = self.cfg.model.recycle_features + self.max_relative_feature = self.cfg.model.max_relative_feature + self.num_bins = self.cfg.model.prev_pos.num_bins + self.min_bin = self.cfg.model.prev_pos.min_bin + self.max_bin = self.cfg.model.prev_pos.max_bin + self.use_chain_relative = self.cfg.model.use_chain_relative + self.max_relative_chain = self.cfg.model.max_relative_chain + self.template_enabled = self.cfg.model.template.enabled + self.num_extra_msa = self.cfg.model.num_extra_msa + self.extra_msa_stack_num = self.cfg.model.evoformer.extra_msa_stack_num + self.msa_stack_num = self.cfg.model.evoformer.msa_stack_num + self.chi_atom_indices, self.chi_angles_mask, _, _, \ + self.indices0, self.indices1 = caculate_constant_array(self.cfg.seq_length) + self.pi = np.pi + self.batch_block = 4 + self.preprocess_1d = nn.Dense(21, self.cfg.model.msa_channel, + weight_init=lecun_init(21)) + self.preprocess_msa = nn.Dense(self.cfg.model.common.msa_feat_dim, self.cfg.model.msa_channel, + weight_init=lecun_init(self.cfg.model.common.msa_feat_dim)) + self.left_single = nn.Dense(21, self.cfg.model.pair_channel, + 21) + self.right_single = nn.Dense(21, self.cfg.model.pair_channel, + weight_init=lecun_init(21)) + self.prev_pos_linear = nn.Dense(self.cfg.model.common.dgram_dim, self.cfg.model.pair_channel, + weight_init=lecun_init(self.cfg.model.common.dgram_dim)) + self.extra_msa_one_hot = nn.OneHot(depth=23, axis=-1) + self.template_aatype_one_hot = nn.OneHot(depth=22, axis=-1) + self.prev_msa_first_row_norm = nn.LayerNorm([256,], epsilon=1e-5) + self.prev_pair_norm = nn.LayerNorm([128,], epsilon=1e-5) + if self.use_chain_relative: + self.rel_pos_one_hot = nn.OneHot(depth=self.cfg.model.max_relative_feature * 2 + 2, axis=-1) + self.rel_chain_one_hot = nn.OneHot(depth=self.max_relative_chain * 2 + 2, axis=-1) + self.position_activations = nn.Dense(self.cfg.model.pair_in_dim, self.cfg.model.pair_channel, + weight_init=lecun_init(self.cfg.model.common.pair_in_dim)) + else: + self.one_hot = nn.OneHot(depth=self.cfg.model.max_relative_feature * 2 + 1, axis=-1) + self.position_activations = nn.Dense(self.cfg.model.common.pair_in_dim, self.cfg.model.pair_channel, + weight_init=lecun_init(self.cfg.model.common.pair_in_dim)) + self.extra_msa_activations = nn.Dense(25, self.cfg.model.extra_msa_channel, weight_init=lecun_init(25)) + self.template_embedding = MultimerTemplateEmbedding(self.cfg.model, mixed_precision) + + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.template_single_embedding = nn.Dense(34, self.cfg.model.msa_channel, + weight_init= + lecun_init(34, initializer_name='relu')) + self.template_projection = nn.Dense(self.cfg.model.msa_channel, self.cfg.model.msa_channel, + weight_init=lecun_init(self.cfg.model.msa_channel, + initializer_name='relu')) + self.relu = nn.ReLU() + self.single_activations = nn.Dense(self.cfg.model.msa_channel, self.cfg.model.seq_channel, + weight_init=lecun_init(self.cfg.model.msa_channel)) + extra_msa_stack = nn.CellList() + for _ in 
range(self.extra_msa_stack_num): + extra_msa_block = MultimerEvoformer(self.cfg.model, + msa_act_dim=64, + pair_act_dim=128, + is_extra_msa=True, + batch_size=None) + extra_msa_stack.append(extra_msa_block) + self.extra_msa_stack = extra_msa_stack + self.msa_stack = MultimerEvoformer(self.cfg.model, + msa_act_dim=256, + pair_act_dim=128, + is_extra_msa=False, + batch_size=self.msa_stack_num) + self.idx_evoformer_block = Parameter(Tensor(0, mstype.int32), requires_grad=False) + self.evoformer_num_block_eval = Tensor(self.msa_stack_num, mstype.int32) + + self.structure_module = MultimerStructureModule(self.cfg, + self.cfg.model.seq_channel, + self.cfg.model.pair_channel) + + self.module_lddt = PredictedLDDTHead(self.cfg.model.heads.predicted_lddt, + self.cfg.model.seq_channel) + + def construct(self, aatype, residue_index, template_aatype, template_all_atom_mask, template_all_atom_positions, + asym_id, sym_id, entity_id, seq_mask, msa_mask, target_feat, msa_feat, + extra_msa, extra_deletion_matrix, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, prev_pos, prev_msa_first_row, prev_pair): + """construct""" + preprocess_1d = self.preprocess_1d(target_feat) + preprocess_msa = self.preprocess_msa(msa_feat) + msa_activations = mnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa + left_single = self.left_single(target_feat) + right_single = self.right_single(target_feat) + pair_activations = P.ExpandDims()(left_single, 1) + P.ExpandDims()(right_single, 0) + mask_2d = P.ExpandDims()(seq_mask, 1) * P.ExpandDims()(seq_mask, 0) + if self.recycle_pos: + prev_pseudo_beta = pseudo_beta_fn(aatype, prev_pos, None) + dgram = dgram_from_positions(prev_pseudo_beta, self.num_bins, self.min_bin, self.max_bin, self._type) + pair_activations += self.prev_pos_linear(dgram) + if self.recycle_features: + prev_msa_first_row = self.prev_msa_first_row_norm(prev_msa_first_row) + msa_activations = mnp.concatenate( + (mnp.expand_dims(prev_msa_first_row + msa_activations[0, ...], 0), msa_activations[1:, ...]), 0) + pair_activations += self.prev_pair_norm(prev_pair) + if self.max_relative_feature: + pair_activations += self._relative_encoding(residue_index, asym_id, sym_id, entity_id) + + if self.template_enabled: + multichain_mask = asym_id[:, None] == asym_id[None, :] + template_pair_representation = self.template_embedding(pair_activations, template_aatype, + template_all_atom_mask, template_all_atom_positions, + mask_2d, multichain_mask) + pair_activations += template_pair_representation + msa_1hot = self.extra_msa_one_hot(extra_msa) + has_deletion = mnp.clip(extra_deletion_matrix, 0., 1.) + deletion_value = (mnp.arctan(extra_deletion_matrix / 3.) * (2. 
/ self.pi)) + extra_msa_feat = mnp.concatenate((msa_1hot, has_deletion[..., None], deletion_value[..., None]), axis=-1) + extra_msa_activations = self.extra_msa_activations(extra_msa_feat) + extra_msa_mask_tmp = P.Transpose()(P.ExpandDims()(extra_msa_mask, -1), (2, 1, 0)) + extra_msa_norm = P.Transpose()(self.batch_matmul_trans_b(extra_msa_mask_tmp, extra_msa_mask_tmp), (1, 2, 0)) + + for i in range(self.extra_msa_stack_num): + extra_msa_activations, pair_activations = \ + self.extra_msa_stack[i](extra_msa_activations, pair_activations, extra_msa_mask, extra_msa_norm, + mask_2d) + if self.template_enabled: + aatype_one_hot = self.template_aatype_one_hot(template_aatype) + chi_angles, chi_mask = compute_chi_angles(template_aatype, + template_all_atom_positions, + template_all_atom_mask, + self.chi_atom_indices, + self.chi_angles_mask, + self.indices0, + self.indices1, + self.batch_block) + template_features = mnp.concatenate([aatype_one_hot, + mnp.sin(chi_angles) * chi_mask, + mnp.cos(chi_angles) * chi_mask, + chi_mask], axis=-1) + template_mask = chi_mask[:, :, 0] + template_activations = self.template_single_embedding(template_features) + template_activations = self.relu(template_activations) + template_activations = self.template_projection(template_activations) + msa_activations = mnp.concatenate([msa_activations, template_activations], axis=0) + msa_mask = mnp.concatenate([msa_mask, template_mask], axis=0) + msa_mask_tmp = P.Transpose()(P.ExpandDims()(msa_mask, -1), (2, 1, 0)) + msa_mask_norm = P.Transpose()(self.batch_matmul_trans_b(msa_mask_tmp, msa_mask_tmp), (1, 2, 0)) + self.idx_evoformer_block = self.idx_evoformer_block * 0 + while self.idx_evoformer_block < self.evoformer_num_block_eval: + msa_activations, pair_activations = self.msa_stack(msa_activations, + pair_activations, + msa_mask, + msa_mask_norm, + mask_2d, + self.idx_evoformer_block) + self.idx_evoformer_block += 1 + single_activations = self.single_activations(msa_activations[0]) + msa_first_row = msa_activations[0] + final_atom_positions, _, rp_structure_module, _, _, \ + _, _, _, _, _ = \ + self.structure_module(single_activations, + pair_activations, + seq_mask, + aatype, + residx_atom37_to_atom14, + atom37_atom_exists) + predicted_lddt_logits = self.module_lddt(rp_structure_module) + final_atom_positions = P.Cast()(final_atom_positions, self._type) + prev_pos = final_atom_positions + prev_msa_first_row = msa_first_row + prev_pair = pair_activations + return prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits + + def _relative_encoding(self, residue_index, asym_id, sym_id, entity_id): + """Add relative position encoding""" + rel_feats = [] + asym_id_same = mnp.equal(asym_id[:, None], asym_id[None, :]) + offset = residue_index[:, None] - residue_index[None, :] + clipped_offset = mnp.clip( + offset + self.max_relative_feature, xmin=0, xmax=2 * self.max_relative_feature) + + if self.use_chain_relative: + final_offset = mnp.where(asym_id_same, clipped_offset, + (2 * self.max_relative_feature + 1) * + mnp.ones_like(clipped_offset)) + rel_pos = self.rel_pos_one_hot(final_offset) + rel_feats.append(rel_pos) + entity_id_same = mnp.equal(entity_id[:, None], entity_id[None, :]) + rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None]) + rel_sym_id = sym_id[:, None] - sym_id[None, :] + max_rel_chain = self.max_relative_chain + clipped_rel_chain = mnp.clip( + rel_sym_id + max_rel_chain, xmin=0, xmax=2 * max_rel_chain) + final_rel_chain = mnp.where(entity_id_same, clipped_rel_chain, + (2 * max_rel_chain + 1) 
* + mnp.ones_like(clipped_rel_chain)) + rel_chain = self.rel_chain_one_hot(final_rel_chain.astype(mstype.int32)) + rel_feats.append(rel_chain) + else: + rel_pos = self.one_hot(clipped_offset) + rel_feats.append(rel_pos) + rel_feat = mnp.concatenate(rel_feats, axis=-1) + return self.position_activations(rel_feat) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f571784cb94561769644010d152485237f577e1e --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"pafnucy" +from .pafnucy import PAFNUCY +from .pafnucy_dataset import PAFNUCYDataSet +from .pafnucy_configuration import pafnucy_configuration diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/nn_arch.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/nn_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..caafc379fa6a345f2bd0e502f24b4d42f401465c --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/nn_arch.py @@ -0,0 +1,178 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""pafnucy Model""" +from math import ceil + +import numpy as np +from mindspore import nn +from mindspore.common import dtype as mstype +from mindspore.common.initializer import TruncatedNormal +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P + + +class HiddenConv3D(nn.Cell): + """HiddenConv3D Cell""" + + def __init__(self, in_channel, out_channel, conv_kernel=5, pool_patch=2, lmbda=0.001): + super(HiddenConv3D, self).__init__() + self.bias_inits = Tensor(np.array([0.1 * out_channel]).astype(np.float32)) + self.conv = nn.Conv3d(in_channels=in_channel, + out_channels=out_channel, + kernel_size=conv_kernel, + stride=1, + pad_mode='same', has_bias=True, weight_init=TruncatedNormal(sigma=lmbda), + bias_init=0.1) + self.relu = nn.ReLU() + self.maxpool3d = P.MaxPool3D(kernel_size=pool_patch, strides=pool_patch, pad_mode='SAME') + + def construct(self, x): + x = self.conv(x) + h = self.relu(x) + h_pool = self.maxpool3d(h) + return h_pool + + +class Conv3DBlock(nn.Cell): + """Conv3DBlock""" + + def __init__(self, in_channel, out_channel, + conv_kernel=5, pool_patch=2, lmbda=0.001): + super(Conv3DBlock, self).__init__() + self.layer1 = HiddenConv3D(in_channel=in_channel[0], + out_channel=out_channel[0], + conv_kernel=conv_kernel, + pool_patch=pool_patch, + lmbda=lmbda) + self.layer2 = HiddenConv3D(in_channel=in_channel[1], + out_channel=out_channel[1], + conv_kernel=conv_kernel, + pool_patch=pool_patch, + lmbda=lmbda) + self.layer3 = HiddenConv3D(in_channel=in_channel[2], + out_channel=out_channel[2], + conv_kernel=conv_kernel, + pool_patch=pool_patch, + lmbda=lmbda) + + def construct(self, x): + x = self.layer1(x) + x = self.layer2(x) + out_c = self.layer3(x) + return out_c + + +class FeedForward(nn.Cell): + """Feed Forward""" + + def __init__(self, fc_size, in_channel, keep_prob=0.5): + super(FeedForward, self).__init__() + self.dense1 = nn.Dense(in_channels=in_channel, + out_channels=fc_size[0], + weight_init=TruncatedNormal(sigma=1 / (in_channel ** 0.5)), + bias_init='one', + has_bias=True, + activation='relu').to_float(mstype.float16) + self.dense2 = nn.Dense(in_channels=fc_size[0], + out_channels=fc_size[1], + weight_init=TruncatedNormal(sigma=1 / (fc_size[0] ** 0.5)), + bias_init='one', + has_bias=True, + activation='relu').to_float(mstype.float16) + self.dense3 = nn.Dense(in_channels=fc_size[1], + out_channels=fc_size[2], + weight_init=TruncatedNormal(sigma=1 / (fc_size[1] ** 0.5)), + bias_init='one', + has_bias=True, + activation='relu').to_float(mstype.float16) + self.dropout = nn.Dropout(keep_prob=keep_prob) + self.dropout1 = nn.Dropout(keep_prob=1.) 
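+        # keep_prob=1.0 makes dropout1 an identity map; construct() selects it
+        # when prob=True so evaluation runs deterministically, while training
+        # keeps the real dropout above.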
+ + def construct(self, x, prob=False): + """construct""" + + x = self.dense1(x) + if prob: + x = self.dropout1(x) + else: + x = self.dropout(x) + x = self.dense2(x) + if prob: + x = self.dropout1(x) + else: + x = self.dropout(x) + out = self.dense3(x) + if prob: + out_f = self.dropout1(out) + else: + out_f = self.dropout(out) + return out_f + + +class SBNetWork(nn.Cell): + """SB network""" + + def __init__(self, in_channel=None, + out_channel=None, + dense_size=None, + osize=1, lmbda=0.01, isize=20, conv_kernel=5, + pool_patch=2, keep_prob=0.5, + is_training=True): + super(SBNetWork, self).__init__() + self.conv3dblock = Conv3DBlock(in_channel, out_channel, + conv_kernel, pool_patch, lmbda) + self.hfsize = isize + self.out_channel = out_channel + self.is_training = is_training + for _ in range(len(self.out_channel)): + self.hfsize = ceil(self.hfsize / pool_patch) + self.hfsize = self.out_channel[-1] * self.hfsize ** 3 + self.reshape = P.Reshape() + self.feedforward = FeedForward(dense_size, self.hfsize, keep_prob=keep_prob).to_float(mstype.float16) + self.out_dense = nn.Dense(in_channels=dense_size[2], + out_channels=osize, + weight_init=TruncatedNormal(sigma=(1 / (dense_size[2] ** 0.5))), + bias_init='one', + has_bias=True, + activation='relu') + self.reduce_mean = P.ReduceMean() + self.pow = P.Pow() + self.mse = nn.MSELoss() + self.rmse = nn.RMSELoss() + self.cast = P.Cast() + + def construct(self, x, target=None, prob=False): + """construct""" + + x = self.conv3dblock(x) + h_flat = self.reshape(x, (-1, self.hfsize)) + h_flat = self.cast(h_flat, mstype.float16) + h_fc = self.feedforward(h_flat, prob=prob) + h_fc = self.cast(h_fc, mstype.float32) + y = self.out_dense(h_fc) + if self.is_training: + mse_out = self.mse(y, target) + else: + mse_out = y + + return mse_out diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy.py new file mode 100644 index 0000000000000000000000000000000000000000..a3a33dc9b9849efce11b0d5fe09b63e2317af2d7 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy.py @@ -0,0 +1,128 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""pafnucy Model""" +import time + +import mindspore as ms +from mindspore import Tensor, context, jit, nn +from mindspore.common import dtype as mstype +from mindspore.common import mutable +from mindspore.nn import TrainOneStepCell + +from ..model import Model +from .nn_arch import SBNetWork + + +class PAFNUCY(Model): + """pafnucy model""" + def __init__(self, config): + context.set_context(memory_optimize_level="O1", max_call_depth=6000) + self.config = config + self.use_jit = self.config.use_jit + self.is_training = self.config.is_training + self.checkpoint_url = 'https://download.mindspore.cn/mindscience/mindsponge/Pafnucy/checkpoint/pafnucy.ckpt' + self.checkpoint_path = "./pafnucy.ckpt" + if self.is_training: + self.network = SBNetWork(in_channel=[19, 64, 128], + out_channel=self.config.conv_channels, + dense_size=self.config.dense_sizes, + lmbda=self.config.lmbda, + isize=self.config.isize, keep_prob=self.config.keep_prob) + self.lr = Tensor(float(self.config.lr), mstype.float32) + optimizer = nn.Adam(params=self.network.trainable_params(), + learning_rate=self.lr, weight_decay=self.config.weight_decay) + self.train_wrapper = TrainOneStepCell(self.network, optimizer=optimizer) + self.network.set_train() + else: + self.network = SBNetWork(in_channel=[19, 64, 128], + out_channel=config.conv_channels, + dense_size=config.dense_sizes, + lmbda=config.lmbda, + isize=config.isize, keep_prob=1.0) + self.network.set_train(False) + + super().__init__(self.checkpoint_url, self.network) + + + def forward(self, data): + if self.use_jit: + result = self._jit_forward(data) + else: + result = self._pynative_forward(data) + return result + + + def predict(self, test_data): + """predict""" + feat = [] + data = {} + data["coords_feature"] = Tensor(test_data["coords_feature"], ms.float32) + if len(data.get("coords_feature").shape) == 4: + data["coords_feature"] = data.get("coords_feature").expand_dims(axis=0) + data["affinity"] = Tensor(test_data.get("affinity"), ms.float32) + + feat.append(data.get("coords_feature")) + feat.append(data.get("affinity")) + feat = mutable(feat) + + t1 = time.time() + result = self.forward(feat) + t2 = time.time() + print(round(t2 - t1, 2)) + return result + + + def loss(self, data): + pass + + + def grad_operations(self, gradient): + pass + + + @jit + def backward(self, feat): + loss = self.train_wrapper(*feat) + return loss + + + def train_step(self, data): + """train step""" + feat = [] + feat.append(Tensor(data.get("coords_feature"), ms.float32)) + feat.append(Tensor(data.get("affinity"), ms.float32)) + feat.append(Tensor(data.get("rot")[-1])) + feat = mutable(feat) + loss = self.backward(feat) + return loss + + + @jit + def _jit_forward(self, data): + result = self.network(*data) + return result + + + def _pynative_forward(self, data): + result = self.network(*data) + return result diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_configuration.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..8c34307c4dc82618be4061cbf548f94cfb4ccc42 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_configuration.py @@ -0,0 +1,26 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore 
Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""pafnucy configuration"""
+pafnucy_configuration = {
+    "config": "https://download.mindspore.cn/mindscience/mindsponge/Pafnucy/config/pafnucy.yaml"
+}
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_data.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..5330db69979c6ec9113d2a61afd9c5305534bb71
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_data.py
@@ -0,0 +1,701 @@
+# Copyright 2023 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""pafnucy data"""
+import os
+import pickle
+import stat
+from itertools import combinations
+from math import ceil, cos, pi, sin, sqrt
+
+import numpy as np
+import pandas as pd
+from openbabel import pybel
+
+
+# pylint: disable=invalid-name
+class Featurizer():
+    """Calculates atomic features for molecules. Features can encode atom type,
+    native pybel properties or any property defined with SMARTS patterns
+
+    Attributes
+    ----------
+    FEATURE_NAMES: list of strings
+        Labels for features (in the same order as features)
+    NUM_ATOM_CLASSES: int
+        Number of atom codes
+    ATOM_CODES: dict
+        Dictionary mapping atomic numbers to codes
+    NAMED_PROPS: list of string
+        Names of atomic properties to retrieve from pybel.Atom object
+    CALLABLES: list of callables
+        Callables used to calculate custom atomic properties
+    SMARTS: list of SMARTS strings
+        SMARTS patterns defining additional atomic properties
+    """
+
+    def __init__(self, atom_codes=None, atom_labels=None,
+                 named_properties=None, save_molecule_codes=True,
+                 custom_properties=None, smarts_properties=None,
+                 smarts_labels=None):
+
+        """Creates Featurizer with specified types of features. Elements of a
+        feature vector will be in the following order: atom type encoding
+        (defined by atom_codes), Pybel atomic properties (defined by
+        named_properties), molecule code (if present), custom atomic properties
+        (defined by `custom_properties`), and additional properties defined with
+        SMARTS (defined with `smarts_properties`).
+
+        Parameters
+        ----------
+        atom_codes: dict, optional
+            Dictionary mapping atomic numbers to codes. It will be used for
+            one-hot encoding, therefore if n different types are used, codes
+            should be from 0 to n-1. Multiple atoms can have the same code,
+            e.g. you can use {6: 0, 7: 1, 8: 1} to encode carbons with [1, 0]
+            and nitrogens and oxygens with [0, 1] vectors. If not provided,
+            default encoding is used.
+        atom_labels: list of strings, optional
+            Labels for atom codes. It should have the same length as the
+            number of used codes, e.g. for `atom_codes={6: 0, 7: 1, 8: 1}` you
+            should provide something like ['C', 'O or N']. If not specified
+            labels 'atom0', 'atom1' etc. are used. If `atom_codes` is not
+            specified this argument is ignored.
+        named_properties: list of strings, optional
+            Names of atomic properties to retrieve from pybel.Atom object. If
+            not specified ['hyb', 'heavyvalence', 'heterovalence',
+            'partialcharge'] is used.
+        save_molecule_codes: bool, optional (default True)
+            If set to True, there will be an additional feature to save
+            molecule code. It is useful when saving a molecular complex in a
+            single array.
+        custom_properties: list of callables, optional
+            Custom functions to calculate atomic properties. Each element of
+            this list should be a callable that takes a pybel.Atom object and
+            returns a float. If a callable has a `__name__` property it is
+            used as the feature label. Otherwise labels 'func<i>' are used,
+            where <i> is the index in the `custom_properties` list.
+        smarts_properties: list of strings, optional
+            Additional atomic properties defined with SMARTS patterns. These
+            patterns should match a single atom. If not specified, default
+            patterns are used.
+        smarts_labels: list of strings, optional
+            Labels for properties defined with SMARTS. Should have the same
+            length as `smarts_properties`. If not specified labels 'smarts0',
+            'smarts1' etc. are used. If `smarts_properties` is not specified
+            this argument is ignored.
+        """
+
+        # Remember names of all features in the correct order
+        # pylint: disable=invalid-name
+        self.FEATURE_NAMES = []
+        # pylint: disable=invalid-name
+        self.__PATTERNS = []
+        if atom_codes is not None:
+            if not isinstance(atom_codes, dict):
+                raise TypeError('Atom codes should be dict, got %s instead'
+                                % type(atom_codes))
+            codes = set(atom_codes.values())
+            for i in range(len(codes)):
+                if i not in codes:
+                    raise ValueError('Incorrect atom code %s' % i)
+
+            # pylint: disable=invalid-name
+            self.NUM_ATOM_CLASSES = len(codes)
+            # pylint: disable=invalid-name
+            self.ATOM_CODES = atom_codes
+            if atom_labels is not None:
+                if len(atom_labels) != self.NUM_ATOM_CLASSES:
+                    raise ValueError('Incorrect number of atom labels: '
+                                     '%s instead of %s'
+                                     % (len(atom_labels), self.NUM_ATOM_CLASSES))
+            else:
+                atom_labels = ['atom%s' % i for i in range(self.NUM_ATOM_CLASSES)]
+            self.FEATURE_NAMES += atom_labels
+        else:
+            self.ATOM_CODES = {}
+
+            metals = ([3, 4, 11, 12, 13] + list(range(19, 32)) +
+                      list(range(37, 51)) + list(range(55, 84)) +
+                      list(range(87, 104)))
+
+            # List of tuples (atomic_num, class_name) with atom types to encode.
+            atom_classes = [
+                (5, 'B'),
+                (6, 'C'),
+                (7, 'N'),
+                (8, 'O'),
+                (15, 'P'),
+                (16, 'S'),
+                (34, 'Se'),
+                ([9, 17, 35, 53], 'halogen'),
+                (metals, 'metal')
+            ]
+
+            for code, (atom, name) in enumerate(atom_classes):
+                if isinstance(atom, list):
+                    for a in atom:
+                        self.ATOM_CODES[a] = code
+                else:
+                    self.ATOM_CODES[atom] = code
+                self.FEATURE_NAMES.append(name)
+
+            self.NUM_ATOM_CLASSES = len(atom_classes)
+
+        if named_properties is not None:
+            if not isinstance(named_properties, (list, tuple, np.ndarray)):
+                raise TypeError('named_properties must be a list')
+            allowed_props = [prop for prop in dir(pybel.Atom)
+                             if not prop.startswith('__')]
+            for prop_id, prop in enumerate(named_properties):
+                if prop not in allowed_props:
+                    raise ValueError(
+                        'named_properties must be in pybel.Atom attributes,'
+                        ' %s was given at position %s' % (prop, prop_id)
+                    )
+            # pylint: disable=invalid-name
+            self.NAMED_PROPS = named_properties
+        else:
+            self.NAMED_PROPS = ['hyb', 'heavydegree', 'heterodegree',
+                                'partialcharge']
+
+        self.FEATURE_NAMES += self.NAMED_PROPS
+
+        if not isinstance(save_molecule_codes, bool):
+            raise TypeError('save_molecule_codes should be bool, got %s '
+                            'instead' % type(save_molecule_codes))
+        self.save_molecule_codes = save_molecule_codes
+        if save_molecule_codes:
+            # Remember if an atom belongs to the ligand or to the protein
+            self.FEATURE_NAMES.append('molcode')
+
+        # pylint: disable=invalid-name
+        self.CALLABLES = []
+        if custom_properties is not None:
+            for i, func in enumerate(custom_properties):
+                if not callable(func):
+                    raise TypeError('custom_properties should be list of'
+                                    ' callables, got %s instead' % type(func))
+                name = getattr(func, '__name__', '')
+                if name == '':
+                    name = 'func%s' % i
+                self.CALLABLES.append(func)
+                self.FEATURE_NAMES.append(name)
+
+        if smarts_properties is None:
+            # SMARTS definition for other properties
+            # pylint: disable=invalid-name
+            self.SMARTS = [
+                '[#6+0!$(*~[#7,#8,F]),SH0+0v2,s+0,S^3,Cl+0,Br+0,I+0]',
+                '[a]',
+                '[!$([#1,#6,F,Cl,Br,I,o,s,nX3,#7v5,#15v5,#16v4,#16v6,*+1,*+2,*+3])]',
+                '[!$([#6,H0,-,-2,-3]),$([!H0;#7,#8,#9])]',
+                '[r]'
+            ]
+            smarts_labels = ['hydrophobic', 'aromatic', 'acceptor', 'donor',
+                             'ring']
+        elif not isinstance(smarts_properties, (list, tuple, np.ndarray)):
+            raise TypeError('smarts_properties must be a list')
+        else:
+            self.SMARTS = smarts_properties
+
+        if smarts_labels is not None:
+            if len(smarts_labels) != len(self.SMARTS):
+                raise ValueError('Incorrect number of SMARTS labels: %s'
+                                 ' instead of %s'
+                                 % (len(smarts_labels), len(self.SMARTS)))
+        else:
+            smarts_labels = ['smarts%s' % i for i in range(len(self.SMARTS))]
+
+        # Compile patterns
+        self.compile_smarts()
+        self.FEATURE_NAMES += smarts_labels
+
+    @staticmethod
+    def from_pickle(fname):
+        """Load pickled featurizer from a given file
+
+        Parameters
+        ----------
+        fname: str, optional
+            Path to file with saved featurizer
+
+        Returns
+        -------
+        featurizer: Featurizer object
+            Loaded featurizer
+        """
+
+        with open(fname, 'rb') as f:
+            featurizer = pickle.load(f)
+        featurizer.compile_smarts()
+        return featurizer
+
+    # pylint: disable=invalid-name
+    def compile_smarts(self):
+        self.__PATTERNS = []
+        for smarts in self.SMARTS:
+            self.__PATTERNS.append(pybel.Smarts(smarts))
+
+    def encode_num(self, atomic_num):
+        """Encode atom type with a binary vector. If atom type is not included in
+        the `atom_classes`, its encoding is an all-zeros vector.
+
+        Parameters
+        ----------
+        atomic_num: int
+            Atomic number
+
+        Returns
+        -------
+        encoding: np.ndarray
+            Binary vector encoding atom type (one-hot or null).
+        """
+
+        if not isinstance(atomic_num, int):
+            raise TypeError('Atomic number must be int, %s was given'
+                            % type(atomic_num))
+
+        encoding = np.zeros(self.NUM_ATOM_CLASSES)
+        # atom types missing from ATOM_CODES keep the all-zeros encoding
+        if atomic_num in self.ATOM_CODES:
+            encoding[self.ATOM_CODES[atomic_num]] = 1.0
+        return encoding
+
+    def find_smarts(self, molecule):
+        """Find atoms that match SMARTS patterns.
+
+        Parameters
+        ----------
+        molecule: pybel.Molecule
+
+        Returns
+        -------
+        features: np.ndarray
+            NxM binary array, where N is the number of atoms in the `molecule`
+            and M is the number of patterns. `features[i, j]` == 1.0 if i'th
+            atom has j'th property
+        """
+
+        if not isinstance(molecule, pybel.Molecule):
+            raise TypeError('molecule must be pybel.Molecule object, %s was given'
+                            % type(molecule))
+
+        features = np.zeros((len(molecule.atoms), len(self.__PATTERNS)))
+
+        for (pattern_id, pattern) in enumerate(self.__PATTERNS):
+            atoms_with_prop = np.array(list(*zip(*pattern.findall(molecule))),
+                                       dtype=int) - 1
+            features[atoms_with_prop, pattern_id] = 1.0
+        return features
+
+    def get_features(self, molecule, molcode=None):
+        """Get coordinates and features for all heavy atoms in the molecule.
+
+        Parameters
+        ----------
+        molecule: pybel.Molecule
+        molcode: float, optional
+            Molecule type. You can use it to encode whether an atom belongs to
+            the ligand (1.0) or to the protein (-1.0) etc.
+
+        Returns
+        -------
+        coords: np.ndarray, shape = (N, 3)
+            Coordinates of all heavy atoms in the `molecule`.
+        features: np.ndarray, shape = (N, F)
+            Features of all heavy atoms in the `molecule`: atom type
+            (one-hot encoding), pybel.Atom attributes, type of a molecule
+            (e.g. protein/ligand distinction), and other properties defined
+            with SMARTS patterns
+        """
+
+        if not isinstance(molecule, pybel.Molecule):
+            raise TypeError('molecule must be pybel.Molecule object,'
+                            ' %s was given' % type(molecule))
+        if molcode is None:
+            if self.save_molecule_codes is True:
+                raise ValueError('save_molecule_codes is set to True,'
+                                 ' you must specify code for the molecule')
+        elif not isinstance(molcode, (float, int)):
+            raise TypeError('molcode must be float, %s was given'
+                            % type(molcode))
+
+        coords = []
+        features = []
+        heavy_atoms = []
+
+        for i, atom in enumerate(molecule):
+            # ignore hydrogens and dummy atoms (they have atomicnum set to 0)
+            if atom.atomicnum > 1:
+                heavy_atoms.append(i)
+                coords.append(atom.coords)
+
+                features.append(np.concatenate((
+                    self.encode_num(atom.atomicnum),
+                    [getattr(atom, prop) for prop in self.NAMED_PROPS],
+                    [func(atom) for func in self.CALLABLES],
+                )))
+
+        coords = np.array(coords, dtype=np.float32)
+        features = np.array(features, dtype=np.float32)
+        if self.save_molecule_codes:
+            features = np.hstack((features,
+                                  molcode * np.ones((len(features), 1))))
+        features = np.hstack([features,
+                              self.find_smarts(molecule)[heavy_atoms]])
+
+        if np.isnan(features).any():
+            raise RuntimeError('Got NaN when calculating features')
+
+        return coords, features
+
+    def to_pickle(self, fname='featurizer.pkl'):
+        """Save featurizer in a given file. Featurizer can be restored with
+        `from_pickle` method.
+
+        Parameters
+        ----------
+        fname: str, optional
+            Path to file in which featurizer will be saved
+        """
+
+        # patterns can't be pickled, we need to temporarily remove them
+        patterns = self.__PATTERNS[:]
+        del self.__PATTERNS
+        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
+        modes = stat.S_IWUSR | stat.S_IRUSR
+        try:
+            with os.fdopen(os.open(fname, flags, modes), 'wb') as fout:
+                pickle.dump(self, fout)
+        finally:
+            self.__PATTERNS = patterns[:]
+
+
+def rotation_matrix(in_axis, in_theta):
+    """Counterclockwise rotation about a given axis by theta radians"""
+
+    if not isinstance(in_axis, (np.ndarray, list, tuple)):
+        raise TypeError('axis must be an array of floats of shape (3,)')
+    try:
+        # the builtin float is used here; np.float was removed from NumPy
+        in_axis = np.asarray(in_axis, dtype=float)
+    except ValueError:
+        raise ValueError('axis must be an array of floats of shape (3,)')
+
+    if in_axis.shape != (3,):
+        raise ValueError('axis must be an array of floats of shape (3,)')
+
+    if not isinstance(in_theta, (float, int)):
+        raise TypeError('theta must be a float')
+
+    in_axis = in_axis / sqrt(np.dot(in_axis, in_axis))
+    a = cos(in_theta / 2.0)
+    b, c, d = -in_axis * sin(in_theta / 2.0)
+    aa, bb, cc, dd = a * a, b * b, c * c, d * d
+    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
+    return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
+                     [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
+                     [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
+
+
+# Create matrices for all possible 90° rotations of a box
+ROTATIONS = [rotation_matrix([1, 1, 1], 0)]
+
+# about X, Y and Z - 9 rotations
+for a1 in range(3):
+    for t in range(1, 4):
+        axis = np.zeros(3)
+        axis[a1] = 1
+        theta = t * pi / 2.0
+        ROTATIONS.append(rotation_matrix(axis, theta))
+
+# about each face diagonal - 6 rotations
+for (a1, a2) in combinations(range(3), 2):
+    axis = np.zeros(3)
+    axis[[a1, a2]] = 1.0
+    theta = pi
+    ROTATIONS.append(rotation_matrix(axis, theta))
+    axis[a2] = -1.0
+    ROTATIONS.append(rotation_matrix(axis, theta))
+
+# about each space diagonal - 8 rotations
+for t in [1, 2]:
+    theta = t * 2 * pi / 3
+    axis = np.ones(3)
+    ROTATIONS.append(rotation_matrix(axis, theta))
+    for a1 in range(3):
+        axis = np.ones(3)
+        axis[a1] = -1
+        ROTATIONS.append(rotation_matrix(axis, theta))
+
+
+def rotate(coords, rotation):
+    """Rotate coordinates by a given rotation
+
+    Parameters
+    ----------
+    coords: array-like, shape (N, 3)
+        Arrays with coordinates and features for each atoms.
+    rotation: int or array-like, shape (3, 3)
+        Rotation to perform. You can either select predefined rotation by
+        giving its index or specify rotation matrix.
+
+    Returns
+    -------
+    coords: np.ndarray, shape = (N, 3)
+        Rotated coordinates.
+    """
+
+    if not isinstance(coords, (np.ndarray, list, tuple)):
+        raise TypeError('coords must be an array of floats of shape (N, 3)')
+    try:
+        coords = np.asarray(coords, dtype=float)
+    except ValueError:
+        raise ValueError('coords must be an array of floats of shape (N, 3)')
+    shape = coords.shape
+    if len(shape) != 2 or shape[1] != 3:
+        raise ValueError('coords must be an array of floats of shape (N, 3)')
+
+    if isinstance(rotation, int):
+        if 0 <= rotation < len(ROTATIONS):
+            out = np.dot(coords, ROTATIONS[rotation])
+        else:
+            raise ValueError('Invalid rotation number %s!' % rotation)
+    elif isinstance(rotation, np.ndarray) and rotation.shape == (3, 3):
+        out = np.dot(coords, rotation)
+    else:
+        raise ValueError('Invalid rotation %s!' % rotation)
+    return out
+
+
+# pylint: disable=invalid-name
+def make_grid(coords, features, grid_resolution=1.0, max_dist=10.0):
+    """Convert atom coordinates and features represented as 2D arrays into a
+    fixed-sized 3D box.
+
+    Parameters
+    ----------
+    coords, features: array-likes, shape (N, 3) and (N, F)
+        Arrays with coordinates and features for each atoms.
+    grid_resolution: float, optional
+        Resolution of a grid (in Angstroms).
+    max_dist: float, optional
+        Maximum distance between atom and box center. Resulting box has size of
+        2*`max_dist`+1 Angstroms and atoms that are too far away are not
+        included.
+
+    Returns
+    -------
+    grid: np.ndarray, shape = (1, M, M, M, F)
+        5D array with atom properties distributed in 3D space. M is equal to
+        2 * `max_dist` / `grid_resolution` + 1
+    """
+
+    try:
+        coords = np.asarray(coords, dtype=float)
+    except ValueError:
+        raise ValueError('coords must be an array of floats of shape (N, 3)')
+    c_shape = coords.shape
+    if len(c_shape) != 2 or c_shape[1] != 3:
+        raise ValueError('coords must be an array of floats of shape (N, 3)')
+
+    N = len(coords)
+    try:
+        features = np.asarray(features, dtype=float)
+    except ValueError:
+        raise ValueError('features must be an array of floats of shape (N, F)')
+    f_shape = features.shape
+    if len(f_shape) != 2 or f_shape[0] != N:
+        raise ValueError('features must be an array of floats of shape (N, F)')
+
+    if not isinstance(grid_resolution, (float, int)):
+        raise TypeError('grid_resolution must be float')
+    if grid_resolution <= 0:
+        raise ValueError('grid_resolution must be positive')
+
+    if not isinstance(max_dist, (float, int)):
+        raise TypeError('max_dist must be float')
+    if max_dist <= 0:
+        raise ValueError('max_dist must be positive')
+
+    num_features = f_shape[1]
+    max_dist_ = float(max_dist)
+    grid_resolution_ = float(grid_resolution)
+
+    box_size = ceil(2 * max_dist_ / grid_resolution_ + 1)
+
+    # move all atoms to the nearest grid point
+    grid_coords = (coords + max_dist_) / grid_resolution_
+    grid_coords = grid_coords.round().astype(int)
+
+    # remove atoms outside the box
+    in_box = ((grid_coords >= 0) & (grid_coords < box_size)).all(axis=1)
+    grid = np.zeros((1, box_size, box_size, box_size, num_features),
+                    dtype=np.float32)
+    for (x, y, z), f in zip(grid_coords[in_box], features[in_box]):
+        grid[0, x, y, z] += f
+
+    return grid
+
+
+def extractfeature(pocket, ligand):
+    """extract features"""
+    featurizer = Featurizer()
+    charge_idx = featurizer.FEATURE_NAMES.index('partialcharge')
+
+    ligand_coords, ligand_features = featurizer.get_features(ligand, molcode=1)
+    assert (ligand_features[:, charge_idx] != 0).any()
+    pocket_coords, pocket_features = featurizer.get_features(pocket, molcode=-1)
+
+    centroid = ligand_coords.mean(axis=0)
+    ligand_coords -= centroid
+    pocket_coords -= centroid
+
+    data = np.concatenate((np.concatenate((ligand_coords, pocket_coords)),
+                           np.concatenate((ligand_features, pocket_features))), axis=1)
+    return data
+
+
+def preprocess(coords, features, config, std, rotation=0):
+    """preprocess"""
+    x = []
+    featurizer = Featurizer()
+
+    columns = {name: i for i, name in enumerate(featurizer.FEATURE_NAMES)}
+
+    coords_rot = rotate(coords, rotation)
+    x.append(make_grid(coords_rot, features, grid_resolution=config.grid_spacing, max_dist=config.max_dist))
+    x = np.vstack(x)
+    x[..., columns['partialcharge']] /= std
+    x = np.transpose(np.squeeze(x), axes=(3, 0, 1, 2))
+    return x
+
+
+def extrct2013ids(in_paths):
+    """Extract pdbbind2013 index"""
+    flags = os.O_WRONLY |
os.O_CREAT | os.O_EXCL + modes = stat.S_IWUSR | stat.S_IRUSR + filepath = os.path.join(in_paths, './v2013-core') + file_idx = os.listdir(filepath) + for items in file_idx: + with os.fdopen(os.open(os.path.join(in_paths, 'core_pdbbind2013.ids'), flags, modes), 'a') as fout: + fout.write(items+'\n') + print("extract 2013 index done!") + + +def parseandclean(paths): + """parse and clean""" + + files = os.path.join( + paths, 'PDBbind_2016_plain_text_index/index/INDEX_general_PL_data.2016') + if os.path.exists('./affinity_data.csv'): + os.remove('./affinity_data.csv') + # Save binding affinities to csv file + result = pd.DataFrame(columns=('pdbid', 'Kd_Ki')) + for line in open(files): + line = line.rstrip() + if line.startswith('#') or line == '': + continue + it = line.split(maxsplit=7) + pdbid, log_kdki = it[0], it[3] + result = result.append( + pd.DataFrame({'pdbid': [pdbid], 'Kd_Ki': [log_kdki]}), + ignore_index=True) + result.to_csv('affinity_data.csv', sep=",", index=False) + affinity_data = pd.read_csv('affinity_data.csv', comment='#') + + # Find affinities without structural data (i.e. with missing directories) + missing_ = [] + for misdata in affinity_data['pdbid']: + gser = os.path.join(paths, f'general-set-except-refined/{misdata}') + refined_set = os.path.join(paths, f'refined-set/{misdata}') + if not os.path.exists(gser) and not os.path.exists(refined_set): + missing_.append(misdata) + missing = set(missing_) + affinity_data = affinity_data[~np.in1d( + affinity_data['pdbid'], list(missing))] + print("Missing length: ", len(missing)) + print(affinity_data['Kd_Ki'].isnull().any()) + + # Separate core, refined, and general sets + core_file = os.path.join( + paths, 'PDBbind_2016_plain_text_index/index/INDEX_core_data.2016') + core_set_ = [] + for c_line in open(core_file): + c_line = c_line.rstrip() + if c_line.startswith('#') or c_line == '': + continue + c_it = c_line.split(maxsplit=7) + core_set_.append(c_it[0]) + core_set = set(core_set_) + print('Core Set length: ', len(core_set)) + refined_file = os.path.join( + paths, 'PDBbind_2016_plain_text_index/index/INDEX_refined_data.2016') + refined_set_ = [] + for rf_line in open(refined_file): + rf_line = rf_line.rstrip() + if rf_line.startswith('#') or rf_line == '': + continue + rf_it = rf_line.split(maxsplit=7) + refined_set_.append(rf_it[0]) + refined_set = set(refined_set_) + general_set = set(affinity_data['pdbid']) + + assert core_set & refined_set == core_set + assert refined_set & general_set == refined_set + + print("Refined Set Length: ", len(refined_set)) + print("General Set Length: ", len(general_set)) + # exclude v2013 core set -- it will be used as another test set + core2013_file = os.path.join(paths, 'core_pdbbind2013.ids') + core2013_ = [] + for c2_line in open(core2013_file): + c2_it = c2_line.rstrip() + core2013_.append(c2_it) + core2013 = set(core2013_) + print("Core2013 length: ", len(core2013)) + print(affinity_data.head()) + print(len(core2013 & (general_set - core_set))) + affinity_data['include'] = True + affinity_data.loc[np.in1d(affinity_data['pdbid'], list( + core2013 & (general_set - core_set))), 'include'] = False + + affinity_data.loc[np.in1d(affinity_data['pdbid'], + list(general_set)), 'set'] = 'general' + affinity_data.loc[np.in1d(affinity_data['pdbid'], + list(refined_set)), 'set'] = 'refined' + affinity_data.loc[np.in1d(affinity_data['pdbid'], + list(core_set)), 'set'] = 'core' + + print(affinity_data.head()) + print(affinity_data[affinity_data['include']].groupby( + 
'set').apply(len).loc[['general', 'refined', 'core']]) + + if os.path.exists('./affinity_data_cleaned.csv'): + os.remove('./affinity_data_cleaned.csv') + affinity_data[['pdbid']].to_csv('pdb.ids', header=False, index=False) + affinity_data[['pdbid', 'Kd_Ki', 'set']].to_csv( + 'affinity_data_cleaned.csv', index=False) + return affinity_data diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_dataset.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1a7682572f0d3b8cca40f0414de1f7edc2a8a7 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/models/pafnucy/pafnucy_dataset.py @@ -0,0 +1,206 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""pafnucy data""" +import os +import warnings + +import numpy as np +from mindspore.dataset import GeneratorDataset +from openbabel import openbabel as ob +from openbabel import pybel +from sklearn.utils import shuffle + +from ...dataset import PDBBind +from .pafnucy_data import (extrct2013ids, parseandclean, extractfeature, + preprocess) + +ob.obErrorLog.SetOutputLevel(0) + + +class PAFNUCYDataSet(PDBBind): + """pafnucy dataset""" + def __init__(self, config): + self.config = config + self.is_training = self.config.is_training + self.std = 0.19213134 # given by shuffle seed 123 + self.data_size = 0 + self.pdbs = {"general": [], "refined": [], "core": []} + self.labels = {} + self.schemalist = ["coords_feature", "affinity", "rot"] + self.general_data_src = "" + self.refine_data_src = "" + self.general_pdbids = [] + self.refine_pdbids = [] + self.training_pdbids = [] + self.training_size = 0 + super().__init__() + + + def __getitem__(self, idx): + data, label = self.data_parse(idx=idx) + rot = False + if self.is_training: + assert self.training_size != 0 + rotation = idx // self.training_size + if rotation >= self.config.rotations: + rotation = 0 + rot = True + else: + rot = False + rotation = rotation.item() + else: + rotation = 0 + features = self.process(data, label, rotation, rot) + tuple_feature = tuple([features.get(key) for key in self.schemalist]) + return tuple_feature + + + def __len__(self): + data_len = self.training_size * (self.config.rotations + 1) + return data_len + + + def get_path(self, pdbid, pdbset): + "get path" + if pdbset == "general": + ligand_path = self.general_data_src + pdbid + f"/{pdbid}_ligand.mol2" + if os.path.exists(self.general_data_src + pdbid + f"/{pdbid}_pocket.mol2"): + pocket_path = self.general_data_src + pdbid + 
f"/{pdbid}_pocket.mol2" + else: + pocket_path = self.general_data_src + pdbid + f"/{pdbid}_pocket.pdb" + molfile = pocket_path.replace(".pdb", ".mol2") + command = "obabel -i pdb %s -o mol2 -O %s" % (pocket_path, molfile) + os.system(command) + pocket_path = molfile + else: + ligand_path = self.refine_data_src + pdbid + f"/{pdbid}_ligand.mol2" + if os.path.exists(self.refine_data_src + pdbid + f"/{pdbid}_pocket.mol2"): + pocket_path = self.refine_data_src + pdbid + f"/{pdbid}_pocket.mol2" + else: + pocket_path = self.refine_data_src + pdbid + f"/{pdbid}_pocket.pdb" + molfile = pocket_path.replace(".pdb", ".mol2") + command = "obabel -i pdb %s -o mol2 -O %s" % (pocket_path, molfile) + os.system(command) + pocket_path = molfile + return ligand_path, pocket_path + + # pylint: disable=arguments-differ + def process(self, data, label=None, rotation=0, rot=False): + """data process""" + assert len(data) == 2 + pocket = data[0] + ligand = data[1] + + feature = extractfeature(pocket, ligand) + coords = feature[:, :3] + features = feature[:, 3:] + coords_feature = preprocess(coords, features, self.config, self.std, rotation=rotation) + coords_feature = np.array(coords_feature, dtype=np.float32) + if label is not None: + affinity = label + else: + affinity = -1 + return {"coords_feature": coords_feature, "affinity": affinity, "rot": rot} + + + def download(self, path=None): + pass + + + def data_parse(self, input_data=None, idx=0): + """data parse""" + if input_data is None: + pdbid = self.training_pdbids[idx][0] + pdbset = self.training_pdbids[idx][1] + else: + pdbid = input_data[0] + pdbset = input_data[1] + assert pdbset in ["general", "refined"] + ligand_path, pocket_path = self.get_path(pdbid, pdbset) + ligand = next(pybel.readfile('mol2', ligand_path)) + try: + pocket = next(pybel.readfile('mol2', pocket_path)) + except ValueError: + warnings.warn('no pocket available.') + label = self.labels.get(pdbid) + data = [pocket, ligand] + return data, label + + + def set_training_data_src(self, data_src=None): + """set training data src""" + if data_src is None: + data_src = self.cache + cmd = "cp {data_src}/index/INDEX_core_data.2016 {data_src}/PDBbind_2016_plain_text_index/index/" + os.system(cmd) + print("Start preprocessing PDBBind data ... 
") + if not os.path.exists(os.path.join(data_src, 'PDBbind_2016_plain_text_index/index/INDEX_general_PL_data.2016')): + raise IOError("INDEX_general_PL_data.2016 file doesn't exit!") + if not os.path.exists(os.path.join(data_src, 'PDBbind_2016_plain_text_index/index/INDEX_core_data.2016')): + raise IOError("INDEX_core_data.2016 file doesn't exit!") + if not os.path.exists(os.path.join(data_src, 'PDBbind_2016_plain_text_index/index/INDEX_refined_data.2016')): + raise IOError("INDEX_refined_data.2016 file doesn't exit!") + if os.path.exists(os.path.join(data_src, 'core_pdbbind2013.ids')): + print("Remove Exist core_pdbbind2013.ids file.") + os.remove(os.path.join(data_src, 'core_pdbbind2013.ids')) + + self.general_data_src = data_src + "general-set-except-refined/" + self.refine_data_src = data_src + "refined-set/" + + extrct2013ids(data_src) + affinity_data = parseandclean(data_src) + self.data_size = len(affinity_data) + for i in range(self.data_size): + pdbid = affinity_data.iloc[i, 0] + pdbset = affinity_data.iloc[i, 3] + ligand_path, pocket_path = self.get_path(pdbid, pdbset) + ligand = next(pybel.readfile('mol2', ligand_path)) + try: + pocket = next(pybel.readfile('mol2', pocket_path)) + except ValueError: + print(ValueError) + continue + if ligand is None or pocket is None: + continue + + if affinity_data.iloc[i, 2]: + self.pdbs[pdbset].append([pdbid, pdbset]) + else: + self.pdbs["core"].append([pdbid, "refined"]) + self.labels[pdbid] = affinity_data.iloc[i, 1] + + self.general_pdbids = self.pdbs.get("general") + self.refine_pdbids = self.pdbs.get("refined") + + refined_shuffled = shuffle(self.refine_pdbids, random_state=123) + self.training_pdbids = self.general_pdbids + refined_shuffled[self.config.size_val:] + self.training_size = len(self.training_pdbids) + self.training_pdbids *= (self.config.rotations + 1) + + + def create_iterator(self, num_epochs): + dataset = GeneratorDataset(source=self, column_names=self.schemalist, + num_parallel_workers=4, shuffle=False, max_rowsize=16) + dataset = dataset.batch(batch_size=20, drop_remainder=True) + iteration = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True) + return iteration diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/pipeline.py b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a200270d28d4164a5f4fe7097958117c105acc73 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/pipeline/pipeline.py @@ -0,0 +1,82 @@ +# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Pipeline""" +import os +import time +import ssl +import urllib.request +from mindspore import context +from ..common.config_load import load_config +from .models import Multimer, MultimerDataSet, multimer_configuration +from .models import COLABDESIGN, ColabDesignDataSet, colabdesign_configuration + +model_card = { + "Multimer": {"model": Multimer, "dataset": MultimerDataSet, "config": multimer_configuration}, + "ColabDesign": {"model": COLABDESIGN, "dataset": ColabDesignDataSet, "config": colabdesign_configuration}, +} + + +def download_config(url, save_path): + if not os.path.exists(save_path): + prefix, _ = os.path.split(save_path) + if not os.path.exists(prefix): + os.makedirs(prefix) + print("Download config to ", save_path) + ssl._create_default_https_context = ssl._create_unverified_context + urllib.request.urlretrieve(url, save_path) + config = load_config(save_path) + return config + + +class PipeLine: + """PipeLine""" + + def __init__(self, name): + self.model_cls = model_card[name]["model"] + self.dataset_cls = model_card[name]["dataset"] + self.config = model_card[name]["config"] + self.model = None + self.dataset = None + self.config_path = "./config/" + + def initialize(self, key): + config = download_config(self.config[key], self.config_path + key + ".yaml") + self.model = self.model_cls(config) + self.dataset = self.dataset_cls(config) + + def set_device_id(self, device_id): + context.set_context(device_id=device_id) + + def predict(self, data): + data = self.dataset.process(data) + result = self.model.predict(data) + return result + + def train(self, data_source, num_epochs): + self.dataset.set_training_data_src(data_source) + data_iter = self.dataset.create_iterator(num_epochs) + for _ in range(num_epochs): + for d in data_iter: + loss = self.model.train_step(d) + print(loss) + + def _test_predict(self, config, run_times=2): + self.initialize(config) + test_data = self.dataset._test_data_parse() + for i in range(run_times): + t1 = time.time() + result = self.predict(test_data) + t2 = time.time() + print("predict times : ", i, " cost : ", t2 - t1, " s") diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60aa617db1354d75de5bd63b21aa0a7dd1cd66e4 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
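With model_card in place, inference through PipeLine reduces to four calls. A minimal usage sketch; the config key and the contents of raw_feature are placeholders that depend on the chosen model's configuration dict:

```python
# Minimal PipeLine usage sketch; "Multimer" is a model_card key, while
# config_key and raw_feature are placeholders for this example.
from mindsponge1.pipeline.pipeline import PipeLine

def run_inference(raw_feature, config_key):
    pipeline = PipeLine("Multimer")
    pipeline.set_device_id(0)
    pipeline.initialize(config_key)    # downloads ./config/<key>.yaml if absent
    return pipeline.predict(raw_feature)
```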
+# ============================================================================
+"""Potential energy"""
+
+from .potential import PotentialCell
+from .forcefield import ForceFieldBase, ForceField
+from .energy import *
+from .bias import *
+
+__all__ = ['PotentialCell', 'ForceFieldBase', 'ForceField']
+__all__.extend(energy.__all__)
+__all__.extend(bias.__all__)
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eb7aeac96e80e29bfd8ea101bf48b14d67d903a
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Bias potential"""
+
+from .bias import Bias
+from .oscillator import OscillatorBias
+from .spherical import SphericalRestrict
+
+__all__ = ['Bias', 'OscillatorBias', 'SphericalRestrict']
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/bias.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/bias.py
new file mode 100644
index 0000000000000000000000000000000000000000..9183ebf085898e1feba397d43206beafd803f7dd
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/bias.py
@@ -0,0 +1,123 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+# Peking University &
+# Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Base cell for bias potential"""
+
+from mindspore import Tensor
+
+from ..potential import PotentialCell
+from ...colvar import Colvar
+from ...function.units import Units, global_units
+
+
+class Bias(PotentialCell):
+    r"""
+    Basic cell for bias potential.
+ + Args: + colvar (Colvar): Collective variables. Default: None + multiple_walkers (bool): Whether to use multiple walkers. Default: False + length_unit (str): Length unit for position coordinates. Default: None + energy_unit (str): Energy unit. Default: None + units (Units): Units of length and energy. Default: global_units + use_pbc (bool): Whether to use periodic boundary condition. Default: None + + Returns: + potential (Tensor), Tensor of shape (B, 1). Data type is float. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, + colvar: Colvar = None, + multiple_walkers: bool = False, + length_unit: str = None, + energy_unit: str = None, + units: Units = global_units, + use_pbc: bool = None, + ): + + super().__init__( + length_unit=length_unit, + energy_unit=energy_unit, + units=units, + use_pbc=use_pbc, + ) + + if units is None: + self.units.set_length_unit(length_unit) + self.units.set_energy_unit(energy_unit) + else: + self.units = units + + self.colvar = colvar + self.multiple_walkers = multiple_walkers + + def update(self, coordinates: Tensor, pbc_box: Tensor = None): + """ + Update parameter of bias potential. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + pbc_box (Tensor, optional): Tensor of shape (B, D) or (1, D). Data type is float. + Box of periodic boundary condition. Default: None. + """ + #pylint: disable = unused-argument + return self + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + pbc_box: Tensor = None + ): + r""" + Calculate bias potential. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. Default: None. + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour atoms. Default: None + neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. Default: None. + pbc_box (Tensor, optional): Tensor of shape (B, D) or (1, D). Data type is float. + Box of periodic boundary condition. Default: None. + + Returns: + potential (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + N: Maximum number of neighbour atoms. + D: Dimension of the simulation system. Usually is 3. + """ + + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/oscillator.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/oscillator.py new file mode 100644 index 0000000000000000000000000000000000000000..d8de0dc183ba7207efd34266b5e13a3156171491 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/oscillator.py @@ -0,0 +1,66 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. 
+# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Harmonic oscillator module. +""" +import mindspore as ms +from mindspore import Tensor +from ..potential import PotentialCell + + +class OscillatorBias(PotentialCell): + """ + Add a restraint for heavy atoms in a molecule. + + Args: + old_crd(Tensor): The origin coordinates of all atoms. + k(float): The elasticity coefficient of all atoms, assuming to be the same. + nonh_mask(Tensor): A mask to distinguish H atoms and heavy atoms. + + Returns: + potential (Tensor). + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, + old_crd, + k, + nonh_mask, + ): + super().__init__() + self.old_crd = Tensor(old_crd, ms.float32) + self.k = Tensor(k, ms.float32) + self.nonh_mask = Tensor(1 - nonh_mask, ms.int32) + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + pbc_box: Tensor = None + ): + shift = coordinate - self.old_crd + energy = 0.5 * self.k * shift ** 2 * self.nonh_mask + return energy.sum(-1).sum(1)[None, :] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/spherical.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/spherical.py new file mode 100644 index 0000000000000000000000000000000000000000..35fa97a7e464fda793fa2f6ac87a3a555d874ddc --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/bias/spherical.py @@ -0,0 +1,131 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
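OscillatorBias is a plain harmonic penalty on the displacement from the reference coordinates, scaled by the stored mask. A NumPy re-statement of the arithmetic in construct (shapes are assumed to be (B, A, 3); note that the code scales the energy by 1 - nonh_mask, so which atoms are restrained depends on the caller's mask convention):

```python
# NumPy re-statement of OscillatorBias.construct:
# E = 0.5 * k * (x - x0)^2, masked and reduced exactly as above.
import numpy as np

def oscillator_energy(coordinate, old_crd, k, nonh_mask):
    shift = coordinate - old_crd
    energy = 0.5 * k * shift ** 2 * (1 - nonh_mask)   # mirrors self.nonh_mask
    return energy.sum(-1).sum(1)[None, :]             # -> (1, B)

crd0 = np.zeros((1, 3, 3), dtype=np.float32)
mask = np.array([0, 0, 1], dtype=np.int32).reshape(1, 3, 1)  # last atom excluded
print(oscillator_energy(crd0 + 0.1, crd0, k=100.0, nonh_mask=mask))  # [[3.0]]
```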
+# ============================================================================
+"""Spherical restraint bias potential"""
+
+import mindspore as ms
+from mindspore import Tensor
+from mindspore import nn
+from mindspore.ops import functional as F
+
+from .bias import Bias
+from ...function.units import Units, global_units, Length, Energy
+from ...function import functions as func
+
+
+class SphericalRestrict(Bias):
+    r"""
+    Bias potential that restrains atoms within a sphere.
+
+    .. math::
+
+        V(R) = k * log(1 + exp((|R - R_0| - r_0) / \sigma))
+
+    Args:
+        radius (float): Radius of sphere (r_0).
+        center (Tensor): Coordinate of the center of sphere (R_0). Default: 0
+        force_constant (float): Force constant of the bias potential (k).
+            Default: Energy(500, 'kj/mol')
+        depth (float): Wall depth of the restriction (\sigma). Default: Length(0.01, 'nm')
+        length_unit (str): Length unit for position coordinates. Default: None
+        energy_unit (str): Energy unit. Default: None
+        units (Units): Units of length and energy. Default: global_units
+        use_pbc (bool): Whether to use periodic boundary condition. Default: None
+
+    Returns:
+        potential (Tensor), Tensor of shape (B, 1). Data type is float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def __init__(self,
+                 radius: float,
+                 center: Tensor = 0,
+                 force_constant: float = Energy(500, 'kj/mol'),
+                 depth: float = Length(0.01, 'nm'),
+                 length_unit: str = None,
+                 energy_unit: str = None,
+                 units: Units = global_units,
+                 use_pbc: bool = None,
+                 ):
+
+        super().__init__(
+            length_unit=length_unit,
+            energy_unit=energy_unit,
+            units=units,
+            use_pbc=use_pbc,
+        )
+
+        self.radius = Tensor(radius, ms.float32)
+        self.center = Tensor(center, ms.float32)
+
+        if isinstance(force_constant, Energy):
+            force_constant = force_constant(self.units)
+        self.force_constant = Tensor(force_constant, ms.float32)
+
+        if isinstance(depth, Length):
+            depth = depth(self.units)
+        self.depth = Tensor(depth, ms.float32)
+
+        self.norm_last_dim = nn.Norm(axis=-1, keep_dims=False)
+
+    def construct(self,
+                  coordinate: Tensor,
+                  neighbour_index: Tensor = None,
+                  neighbour_mask: Tensor = None,
+                  neighbour_coord: Tensor = None,
+                  neighbour_distance: Tensor = None,
+                  pbc_box: Tensor = None
+                  ):
+        r"""
+        Calculate bias potential.
+
+        Args:
+            coordinate (Tensor): Tensor of shape (B, A, D). Data type is float.
+                Position coordinate of atoms in system.
+            neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int.
+                Index of neighbour atoms. Default: None
+            neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool.
+                Mask for neighbour atoms. Default: None
+            neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool.
+                Position coordinates of neighbour atoms.
+            neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float.
+                Distance between neighbouring atoms. Default: None
+            pbc_box (Tensor): Tensor of shape (B, D). Data type is float.
+                Tensor of PBC box. Default: None
+
+        Returns:
+            potential (Tensor), Tensor of shape (B, 1). Data type is float.
+
+        Symbols:
+            B: Batchsize, i.e. number of walkers in simulation.
+            A: Number of atoms.
+            N: Maximum number of neighbour atoms.
+            D: Dimension of the simulation system. Usually is 3.
+ """ + + # (B, A) <- (B, A, D) + distance = self.norm_last_dim(coordinate - self.center) + diff = distance - self.radius + bias = self.force_constant * F.log(1.0 + F.exp(diff/self.depth)) + + # (B, 1) <- (B, A) + return func.keepdim_sum(bias, -1) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10461ce4d29f83b63bab6bafd0604f7b64b6ee50 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Energy terms +""" + +from .energy import EnergyCell, NonbondEnergy +from .bond import BondEnergy +from .angle import AngleEnergy +from .dihedral import DihedralEnergy +from .coulomb import CoulombEnergy +from .lj import LennardJonesEnergy +from .pairs import NonbondPairwiseEnergy + +__all__ = ['EnergyCell', 'NonbondEnergy', 'BondEnergy', 'AngleEnergy', 'DihedralEnergy', + 'CoulombEnergy', 'LennardJonesEnergy', 'NonbondPairwiseEnergy'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/angle.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/angle.py new file mode 100644 index 0000000000000000000000000000000000000000..2d8197128cfd45160e9517257703c4c77d78e493 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/angle.py @@ -0,0 +1,184 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Angle energy""" + +import mindspore as ms +from mindspore import Tensor +from mindspore import Parameter +from mindspore.ops import functional as F + +from .energy import EnergyCell +from ...colvar import AtomAngles +from ...function import functions as func +from ...function.units import Units + + +class AngleEnergy(EnergyCell): + r""" + Energy term of bond angles. + + .. Math:: + + E_{angle}(\theta_{ijk}) = 1 / 2 \times k_{ijk}^\theta \times (\theta_{ijk} - \theta_{ijk}^0) ^ 2 + + Args: + index (Tensor): Tensor of shape (B, a, 3). Data type is int. + Atom index of bond angles. Default: None + force_constant (Tensor): Tensor of shape (1, a). Data type is float. + The harmonic force constants for angle :math:`(k^{\theta})`. Default: None + bond_angle (Tensor): Tensor of shape (1, a). Data type is float. + The equilibrium value of bond angle :math:`({\theta}^0)`. Default: None + parameters (dict): Force field parameters. Default: None + use_pbc (bool): Whether to use periodic boundary condition. Default: None + energy_unit (str): Energy unit. Default: 'kj/mol' + units (Units): Units of length and energy. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + a: Number of angles. + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + index: Tensor = None, + force_constant: Tensor = None, + bond_angle: Tensor = None, + parameters: dict = None, + use_pbc: bool = None, + energy_unit: str = 'kj/mol', + units: Units = None, + ): + + super().__init__( + label='angle_energy', + output_dim=1, + use_pbc=use_pbc, + energy_unit=energy_unit, + units=units, + ) + + if parameters is not None: + length_unit = parameters.get('length_unit') + energy_unit = parameters.get('energy_unit') + self.units.set_units(length_unit, energy_unit) + + index = parameters.get('index') + force_constant = parameters.get('force_constant') + bond_angle = parameters.get('bond_angle') + + # (1,a,3) + index = Tensor(index, ms.int32) + if index.shape[-1] != 3: + raise ValueError('The last dimension of index in AngleEnergy must be 3 but got: ' + + str(index.shape[-1])) + if index.ndim == 2: + index = F.expand_dims(index, 0) + if index.ndim != 3: + raise ValueError('The rank of index must be 2 or 3 but got shape: '+str(index.shape)) + self.index = Parameter(index, name='angle_index', requires_grad=False) + + self.num_angles = index.shape[-2] + + # (1,a) + force_constant = Tensor(force_constant, ms.float32) + if force_constant.shape[-1] != self.num_angles: + raise ValueError('The last shape of force_constant ('+str(force_constant.shape[-1]) + + ') must be equal to num_angles ('+str(self.num_angles)+')!') + if force_constant.ndim == 1: + force_constant = F.expand_dims(force_constant, 0) + if force_constant.ndim > 2: + raise ValueError('The rank of force_constant cannot be larger than 2!') + self.force_constant = Parameter(force_constant, name='angle_force_constant') + + bond_angle = Tensor(bond_angle, ms.float32) + if bond_angle.shape[-1] != self.num_angles: + raise ValueError('The last shape of bond_angle ('+str(bond_angle.shape[-1]) + + ') must be equal to num_angles ('+str(self.num_angles)+')!') + if bond_angle.ndim == 1: + bond_angle = F.expand_dims(bond_angle, 0) + if bond_angle.ndim > 2: + raise ValueError('The rank of bond_angle cannot be larger than 2!') + 
self.bond_angle = Parameter(bond_angle, name='bond_angle') + + self.get_angle = AtomAngles(self.index, use_pbc=use_pbc) + + def set_pbc(self, use_pbc=None): + self.use_pbc = use_pbc + self.get_angle.set_pbc(use_pbc) + return self + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + inv_neigh_dis: Tensor = None, + pbc_box: Tensor = None, + ): + r""" + Calculate energy term. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour index. + neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. + inv_neigh_dis (Tensor): Tensor of shape (B, A, N). Data type is float. + Reciprocal of distances. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + + """ + # (B,M) + theta = self.get_angle(coordinate, pbc_box) + # (B,M) = (B,M) - (1,M) + dtheta = theta - self.bond_angle + dtheta2 = dtheta * dtheta + + # E_angle = 1/2 * k_\theta * (\theta-\theta_0)^2 + # (B,M) = (1,M) * (B,M) * k + energy = 0.5 * self.force_constant * dtheta2 + + # (B,1) <- (B,M) + return func.keepdim_sum(energy, -1) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/bond.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/bond.py new file mode 100644 index 0000000000000000000000000000000000000000..1691a887f77112dcb04dfc0827ea539cfa4109ce --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/bond.py @@ -0,0 +1,191 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
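To make the shape bookkeeping in `AngleEnergy.construct` concrete, here is a hedged NumPy sketch that computes the angle theta directly from coordinates (standing in for the `AtomAngles` colvar, whose implementation is not part of this diff) and then applies E = 1/2 * k * (theta - theta0)^2. The helper name `angle_energy` and the water-like example values are illustrative assumptions.

```python
import numpy as np

def angle_energy(coord, index, k, theta0):
    """E_angle = 1/2 * k * (theta - theta0)^2, summed over angles.

    coord: (B, A, D); index: (a, 3) atom triples (i, j, k with j the vertex);
    k, theta0: (a,) force constants and equilibrium angles (radians).
    """
    # (B, a, D): vectors from the vertex atom to its two neighbours
    v1 = coord[:, index[:, 0]] - coord[:, index[:, 1]]
    v2 = coord[:, index[:, 2]] - coord[:, index[:, 1]]
    cos_theta = np.sum(v1 * v2, -1) / (
        np.linalg.norm(v1, axis=-1) * np.linalg.norm(v2, axis=-1))
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))   # (B, a)
    energy = 0.5 * k * (theta - theta0) ** 2
    return energy.sum(-1, keepdims=True)               # (B, 1)

# One water-like angle: H-O-H near 104.5 degrees, so the energy is ~0
coord = np.array([[[0.96, 0.0, 0.0], [0.0, 0.0, 0.0], [-0.24, 0.93, 0.0]]])
print(angle_energy(coord, np.array([[0, 1, 2]]),
                   k=np.array([400.0]), theta0=np.array([np.deg2rad(104.5)])))
```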
+# ============================================================================ +"""Bond energy""" + +import mindspore as ms +from mindspore import Tensor +from mindspore import Parameter +from mindspore.ops import functional as F + +from .energy import EnergyCell +from ...colvar import AtomDistances +from ...function import functions as func +from ...function.units import Units + + +class BondEnergy(EnergyCell): + r""" + Energy term of bond length. + + .. Math:: + + E_{bond}(b_{ij}) = 1 / 2 * k_{ij}^b * (b_{ij} - b_{ij}^0) ^ 2 + + Args: + index (Tensor): Tensor of shape (B, b, 2). Data type is int. + Atom index of bond. + force_constant (Tensor): Tensor of shape (1, b). Data type is float. + The harmonic force constants of bond length (k^b). + bond_length (Tensor): Tensor of shape (1, b). Data type is float. + The equilibrium value of bond length (b^0). + parameters (dict): Force field parameters. Default: None + use_pbc (bool): Whether to use periodic boundary condition. + length_unit (str): Length unit for position coordinates. Default: 'nm' + energy_unit (str): Energy unit. Default: 'kj/mol' + units (Units): Units of length and energy. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + b: Number of bonds. + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + index: Tensor = None, + force_constant: Tensor = None, + bond_length: Tensor = None, + parameters: dict = None, + use_pbc: bool = None, + length_unit: str = 'nm', + energy_unit: str = 'kj/mol', + units: Units = None, + ): + + super().__init__( + label='bond_energy', + output_dim=1, + use_pbc=use_pbc, + length_unit=length_unit, + energy_unit=energy_unit, + units=units, + ) + + if parameters is not None: + length_unit = parameters.get('length_unit') + energy_unit = parameters.get('energy_unit') + self.units.set_units(length_unit, energy_unit) + + index = parameters.get('index') + force_constant = parameters.get('force_constant') + bond_length = parameters.get('bond_length') + + # (B,b,2) + index = Tensor(index, ms.int32) + if index.shape[-1] != 2: + raise ValueError('The last dimension of index in BondEnergy must be 2 but got: ' + + str(index.shape[-1])) + if index.ndim == 2: + index = F.expand_dims(index, 0) + if index.ndim != 3: + raise ValueError('The rank of index must be 2 or 3 but got shape: '+str(index.shape)) + self.index = Parameter(index, name='bond_index', requires_grad=False) + + # (B,b) + self.get_bond_length = AtomDistances(self.index, use_pbc=use_pbc, length_unit=self.units) + + # b + self.num_bonds = index.shape[-2] + + # (B,b) + force_constant = Tensor(force_constant, ms.float32) + if force_constant.shape[-1] != self.num_bonds: + raise ValueError('The last shape of force_constant ('+str(force_constant.shape[-1]) + + ') must be equal to num_bonds ('+str(self.num_bonds)+')!') + if force_constant.ndim == 1: + force_constant = F.expand_dims(force_constant, 0) + if force_constant.ndim > 2: + raise ValueError('The rank of force_constant cannot be larger than 2!') + self.force_constant = Parameter(force_constant, name='bond_force_constant') + + bond_length = Tensor(bond_length, ms.float32) + if bond_length.shape[-1] != self.num_bonds: + raise ValueError('The last shape of bond_length ('+str(bond_length.shape[-1]) + + ') must be equal to num_bonds ('+str(self.num_bonds)+')!') + if bond_length.ndim == 1: + bond_length = 
F.expand_dims(bond_length, 0) + if bond_length.ndim > 2: + raise ValueError('The rank of bond_length cannot be larger than 2!') + self.bond_length = Parameter(bond_length, name='bond_length') + + def set_pbc(self, use_pbc=None): + self.use_pbc = use_pbc + self.get_bond_length.set_pbc(use_pbc) + return self + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + inv_neigh_dis: Tensor = None, + pbc_box: Tensor = None, + ): + r""" + Calculate energy term. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour index. + neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. + inv_neigh_dis (Tensor): Tensor of shape (B, A, N). Data type is float. + Reciprocal of distances. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + + """ + + # (B,b) + dis = self.get_bond_length(coordinate, pbc_box) * self.input_unit_scale + + # (B,b) = (B,b) - (B,b) + diff = dis - self.bond_length + # (B,b) + diff2 = F.square(diff) + + # (B,b) = (1,b) * (B,b) * k + energy = 0.5 * self.force_constant * diff2 + + # (B,1) <- (B,b) + return func.keepdim_sum(energy, -1) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/coulomb.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/coulomb.py new file mode 100644 index 0000000000000000000000000000000000000000..e1415f7fe5ebd4281f3debcc3e8ff33e75e45117 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/coulomb.py @@ -0,0 +1,621 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
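The bond term just closed follows the same pattern as the angle term. The sketch below (plain NumPy, illustrative values, no PBC or unit handling) reproduces E = 1/2 * k * (b - b0)^2 with the (B, b) shapes annotated in the comments of `BondEnergy.construct`; the helper name `bond_energy` is an assumption for this example only.

```python
import numpy as np

def bond_energy(coord, index, k, b0):
    """E_bond = 1/2 * k * (b - b0)^2, summed over bonds.

    coord: (B, A, D); index: (b, 2) bonded atom pairs;
    k, b0: (b,) force constants and equilibrium lengths.
    """
    # (B, b): current bond lengths
    vec = coord[:, index[:, 0]] - coord[:, index[:, 1]]
    length = np.linalg.norm(vec, axis=-1)
    energy = 0.5 * k * (length - b0) ** 2
    return energy.sum(-1, keepdims=True)   # (B, 1)

coord = np.array([[[0.0, 0.0, 0.0], [0.11, 0.0, 0.0], [0.11, 0.15, 0.0]]])
index = np.array([[0, 1], [1, 2]])
# Bond 0 sits 0.01 beyond its b0; bond 1 is exactly at equilibrium.
print(bond_energy(coord, index, k=np.array([2.5e5, 2.5e5]),
                  b0=np.array([0.10, 0.15])))
```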
+# ============================================================================
+"""Electrostatic interaction"""
+from numpy import exp
+
+import mindspore as ms
+import mindspore.numpy as msnp
+from mindspore import Tensor, Parameter
+from mindspore import ms_function
+from mindspore import ops
+from mindspore.nn import Cell
+from mindspore.ops import functional as F
+
+from ...colvar import AtomDistances
+from .energy import NonbondEnergy
+from ...function import functions as func
+from ...function.functions import gather_values
+from ...function.units import Units
+
+
+@ms_function
+def coulomb_interaction(qi: Tensor, qj: Tensor, inv_dis: Tensor, mask: Tensor = None):
+    """Calculate Coulomb interaction using Coulomb's law."""
+
+    # (B,A,N) = (B,A,1) * (B,A,N)
+    qiqj = qi * qj
+
+    # (B,A,N)
+    energy = qiqj * inv_dis
+
+    if mask is not None:
+        # (B,A,N) * (B,A,N)
+        energy *= mask
+
+    # (B,A)
+    energy = F.reduce_sum(energy, -1)
+    # (B,1)
+    energy = func.keepdim_sum(energy, 1) * 0.5
+
+    return energy
+
+
+class CoulombEnergy(NonbondEnergy):
+    r"""
+    Coulomb interaction.
+
+    .. math::
+
+        E_{ele}(r_{ij}) = \sum_{ij} k_{coulomb} \times q_i \times q_j / r_{ij}
+
+    Args:
+        atom_charge (Tensor): Tensor of shape (B, A). Data type is float.
+            Atom charge. Default: None.
+        parameters (dict): Force field parameters. Default: None.
+        cutoff (float): Cutoff distance. Default: None.
+        use_pbc (bool, optional): Whether to use periodic boundary condition. Default: None.
+        use_pme (bool, optional): Whether to use Particle Mesh Ewald (PME).
+            If None, PME is used whenever use_pbc is True. Default: None.
+        alpha (float): Alpha for DSF and PME coulomb interaction. Default: 0.25.
+        nfft (Tensor): Parameter of FFT, required by PME. Default: None.
+        exclude_index (Tensor): Tensor of the exclude index, required by PME. Default: None.
+        length_unit (str): Length unit for position coordinates. Default: 'nm'.
+        energy_unit (str): Energy unit. Default: 'kj/mol'.
+        units (Units): Units of length and energy. Default: None.
+
+    Returns:
+        energy (Tensor), Tensor of shape (B, 1). Data type is float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self,
+                 atom_charge: Tensor = None,
+                 parameters: dict = None,
+                 cutoff: float = None,
+                 use_pbc: bool = None,
+                 use_pme: bool = None,
+                 alpha: float = 0.25,
+                 nfft: Tensor = None,
+                 exclude_index: Tensor = None,
+                 length_unit: str = 'nm',
+                 energy_unit: str = 'kj/mol',
+                 units: Units = None,
+                 ):
+
+        super().__init__(
+            label='coulomb_energy',
+            output_dim=1,
+            cutoff=cutoff,
+            use_pbc=use_pbc,
+            length_unit=length_unit,
+            energy_unit=energy_unit,
+            units=units,
+        )
+
+        if parameters is not None:
+            length_unit = parameters.get('length_unit')
+            energy_unit = parameters.get('energy_unit')
+            self.units.set_units(length_unit, energy_unit)
+
+        self.atom_charge = self.identity(atom_charge)
+        self.coulomb_const = Tensor(self.units.coulomb, ms.float32)
+
+        # Default None (as documented) so that PME follows use_pbc unless set explicitly.
+        if use_pme is None:
+            use_pme = use_pbc
+        self.use_pme = use_pme
+        if self.use_pme and (not self.use_pbc):
+            raise ValueError('PME cannot be used without periodic boundary conditions')
+
+        self.pme_coulomb = None
+        self.dsf_coulomb = None
+        if self.use_pme:
+            self.pme_coulomb = ParticleMeshEwaldCoulomb(self.cutoff, alpha, nfft, exclude_index, self.units)
+        else:
+            self.dsf_coulomb = DampedShiftedForceCoulomb(self.cutoff, alpha)
+
+    def set_cutoff(self, cutoff: Tensor):
+        """
+        Set cutoff distance.
+
+        Args:
+            cutoff (Tensor): Cutoff distance. Default: None.
+ """ + if cutoff is None: + if self.use_pbc: + raise ValueError('cutoff cannot be none when using periodic boundary condition') + self.cutoff = None + else: + self.cutoff = Tensor(cutoff, ms.float32) + if self.dsf_coulomb is not None: + self.dsf_coulomb.set_cutoff(cutoff) + if self.pme_coulomb is not None: + self.pme_coulomb.set_cutoff(cutoff) + return self + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + inv_neigh_dis: Tensor = None, + pbc_box: Tensor = None, + ): + r""" + Calculate energy term. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour index. + neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. + inv_neigh_dis (Tensor): Tensor of shape (B, A, N). Data type is float. + Reciprocal of distances. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + """ + + inv_neigh_dis *= self.inverse_input_scale + + # (B,A,1) + qi = F.expand_dims(self.atom_charge, -1) + # (B,A,N) + qj = gather_values(self.atom_charge, neighbour_index) + + if self.cutoff is None: + energy = coulomb_interaction(qi, qj, inv_neigh_dis, neighbour_mask) + else: + neighbour_distance *= self.input_unit_scale + if self.use_pme: + energy = self.pme_coulomb(coordinate, + qi, qj, neighbour_distance, + inv_neigh_dis, neighbour_mask, + pbc_box) + else: + energy = self.dsf_coulomb( + qi, qj, neighbour_distance, inv_neigh_dis, neighbour_mask) + + return energy * self.coulomb_const + + +class DampedShiftedForceCoulomb(Cell): + r"""Damped shifted force coulomb potential. + + Args: + + atom_charge (Tensor): Tensor of shape (B, A). Data type is float. + Atom charge. + + cutoff (float): Cutoff distance. Default: None + + alpha (float): Alpha. Default: 0.25 + + use_pbc (bool): Whether to use periodic boundary condition. Default: None + + length_unit (str): Length unit for position coordinates. Default: None + + energy_unit (str): Energy unit. Default: None + + units (Units): Units of length and energy. 
Default: None + + """ + + def __init__(self, + cutoff: float = None, + alpha: float = 0.25, + ): + + super().__init__() + + self.alpha = Parameter(Tensor(alpha, ms.float32), name='alpha', requires_grad=False) + + self.erfc = ops.Erfc() + self.f_shift = None + self.e_shift = None + if cutoff is not None: + self.set_cutoff(cutoff) + + def set_cutoff(self, cutoff: Tensor): + """set cutoff distance""" + self.cutoff = Tensor(cutoff, ms.float32) + cutoff2 = F.square(self.cutoff) + erfcc = self.erfc(self.alpha * self.cutoff) + erfcd = msnp.exp(-F.square(self.alpha) * cutoff2) + + self.f_shift = -(erfcc / cutoff2 + 2 / msnp.sqrt(msnp.pi) + * self.alpha * erfcd / self.cutoff) + self.e_shift = erfcc / self.cutoff - self.f_shift * self.cutoff + + def construct(self, + qi: Tensor, + qj: Tensor, + dis: Tensor, + inv_dis: Tensor, + mask: Tensor = None, + ): + r"""Calculate energy term. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour index. + neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. + inv_neigh_dis (Tensor): Tensor of shape (B, A, N). Data type is float. + Reciprocal of distances. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + energy (Tensor): Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + + """ + + # (B,A,N) = (B,A,1) * (B,A,N) + qiqj = qi*qj + energy = qiqj * inv_dis * (self.erfc(self.alpha * dis) - + dis * self.e_shift - F.square(dis) * self.f_shift) + + if mask is None: + mask = dis < self.cutoff + else: + mask = F.logical_and(mask, dis < self.cutoff) + + energy = msnp.where(mask, energy, 0.0) + + # (B,A) + energy = F.reduce_sum(energy, -1) + # (B,1) + energy = func.keepdim_sum(energy, 1) * 0.5 + + return energy + +#pylint: disable=unused-argument +class RFFT3D(Cell): + r"""rfft3d""" + def __init__(self, fftx, ffty, fftz, fftc, inverse): + Cell.__init__(self) + self.cast = ms.ops.Cast() + self.rfft3d = ms.ops.FFT3D() + self.irfft3d = ms.ops.IFFT3D() + self.inverse = inverse + if self.inverse: + self.norm = msnp.ones(fftc, dtype=ms.float32) * fftx * ffty * fftz + self.norm = 1 / self.norm + self.norm[1:-1] *= 2 + else: + self.norm = msnp.ones(fftc, dtype=ms.float32) * fftx * ffty * fftz + self.norm[1:-1] /= 2 + + def construct(self, x): + if self.inverse: + return self.irfft3d(x) + return self.rfft3d(x) + + def bprop(self, x, out, dout): + if self.inverse: + ans = self.rfft3d(dout) + else: + ans = self.irfft3d(dout) + return (ans,) + + +class ParticleMeshEwaldCoulomb(Cell): + r"""Particle mesh ewald algorithm for electronic interaction + + Args: + + atom_charge (Tensor): Tensor of shape (B, A). Data type is float. + Atom charge. + + cutoff (float): Cutoff distance. Default: None + + alpha (float): the parameter of the Gaussian charge. Default: 0.275106 + + nfft (Tensor): Tensor of FFT parameter. Default: None + + exclude_index (Tensor): Tensor of the exclude index. Default: None + + units (Units): Units of length and energy. 
Default: None + """ + + def __init__(self, + cutoff: float = None, + alpha: float = 0.275106, + nfft: Tensor = None, + exclude_index: Tensor = None, + units: Units = None): + + super().__init__() + + self.units = units + self.cutoff = cutoff + self.alpha = Tensor(0.275106, ms.float32) + self.erfc = ops.Erfc() + self.input_unit_scale = 1 + self.exclude_index = None + self.exclude_pairs = None + self.get_exclude_distance = None + self.nfft = None + self.fftx = None + self.ffty = None + self.fftz = None + self.fftc = None + self.fftx = None + self.ffty = None + self.fftz = None + self.b = None + self.rfft3d = None + self.irfft3d = None + self.set_nfft(nfft) + self.double_gradient = Double_Gradient() + #self.set_nfft([[4,4,4]]) + print(self.nfft, self.alpha) + self.cast = ms.ops.Cast() + + ma = [1.0 / 6.0, -0.5, 0.5, -1.0 / 6.0] + ma = Tensor([[ma[i], ma[j], ma[k]]for i in range(4) for j in range(4) for k in range(4)], ms.float32) + self.ma = ma.reshape(1, 1, 64, 3) + mb = [0, 0.5, -1, 0.5] + mb = Tensor([[mb[i], mb[j], mb[k]]for i in range(4) for j in range(4) for k in range(4)], ms.float32) + self.mb = mb.reshape(1, 1, 64, 3) + mc = [0, 0.5, 0, -0.5] + mc = Tensor([[mc[i], mc[j], mc[k]]for i in range(4) for j in range(4) for k in range(4)], ms.float32) + self.mc = mc.reshape(1, 1, 64, 3) + md = [0, 1.0 / 6.0, 4.0 / 6.0, 1.0 / 6.0] + md = Tensor([[md[i], md[j], md[k]]for i in range(4) for j in range(4) for k in range(4)], ms.float32) + self.md = md.reshape(1, 1, 64, 3) + self.base_grid = Tensor([[i, j, k] for i in range(4) for j in range(4) for k in range(4)], + ms.int32).reshape(1, 1, 64, 3) + self.batch_constant = msnp.ones((exclude_index.shape[0], exclude_index.shape[1], 64, 1), ms.int32) + self.batch_constant *= msnp.arange(0, exclude_index.shape[0]).reshape(-1, 1, 1, 1) + self.set_exclude_index(exclude_index) + if units: + self.set_input_unit(units) + if alpha: + self.set_alpha(alpha) + + @staticmethod + def _m(u, n): + """get factor m""" + if n == 2: + if u > 2 or u < 0: + return 0 + return 1 - abs(u - 1) + self = ParticleMeshEwaldCoulomb._m + return u / (n - 1) * self(u, n - 1) + (n - u) / (n - 1) * self(u - 1, n - 1) + + @staticmethod + def _b(k, fftn, order=4): + """get factor b""" + tempc2 = complex(0, 0) + tempc = complex(0, 2 * (order - 1) * msnp.pi * k / fftn) + res = exp(tempc) + for kk in range(order - 1): + tempc = complex(0, 2 * msnp.pi * k / fftn * kk) + tempc = exp(tempc) + tempf = ParticleMeshEwaldCoulomb._m(kk + 1, order) + tempc2 += tempf * tempc + res = res / tempc2 + return abs(res) * abs(res) + + def set_input_unit(self, units: Units): + """set the length unit for the input coordinates""" + if units is None: + self.input_unit_scale = 1 + elif isinstance(units, Units): + self.input_unit_scale = Tensor( + self.units.convert_length_from(units), ms.float32) + else: + raise TypeError('Unsupported type: '+str(type(units))) + return self + + def set_cutoff(self, cutoff: Tensor): + """set cutoff distance""" + self.cutoff = Tensor(cutoff, ms.float32) + + def set_alpha(self, alpha: Tensor): + """set the parameter beta""" + self.alpha = Tensor(alpha, ms.float32) + + def set_exclude_index(self, exclude_index: Tensor): + """set exclude index""" + if exclude_index is None: + self.exclude_index = None + else: + if exclude_index.ndim != 3: + raise ValueError('The rank of exclude index must be 3.') + if exclude_index.shape[2] == 0: + self.exclude_index = None + else: + self.exclude_index = Tensor(exclude_index, ms.int32) + if self.exclude_index is not None: + t = [] + for batch in 
self.exclude_index: + t.append([]) + for i, ex in enumerate(batch): + for ex_atom in ex: + if i < ex_atom < self.exclude_index.shape[1]: + t[-1].append([i, ex_atom]) + self.exclude_pairs = msnp.array(t) + self.get_exclude_distance = AtomDistances(self.exclude_pairs, use_pbc=True, length_unit=self.units) + + def set_nfft(self, nfft: Tensor): + """set nfft""" + self.nfft = Tensor(nfft, ms.int32).reshape((-1, 1, 3)) + self.fftx = int(self.nfft[0][0][0]) + self.ffty = int(self.nfft[0][0][1]) + self.fftz = int(self.nfft[0][0][2]) + if self.fftx % 4 != 0 or self.ffty % 4 != 0 or self.fftz % 4 != 0: + raise ValueError("The FFT grid number for PME must be a multiple of 4") + self.fftc = self.fftz // 2 + 1 + self.ffkx = msnp.arange(self.fftx) + self.ffkx = msnp.where(self.ffkx > self.fftx / 2, self.fftx - self.ffkx, self.ffkx).reshape(-1, 1, 1) + self.ffky = msnp.arange(self.ffty) + self.ffky = msnp.where(self.ffky > self.ffty / 2, self.ffty - self.ffky, self.ffky).reshape(1, -1, 1) + self.ffkz = msnp.arange(self.fftc).reshape(1, 1, -1) + + bx = msnp.array([self._b(i, self.fftx) for i in range(self.fftx)]) + by = msnp.array([self._b(i, self.ffty) for i in range(self.ffty)]) + bz = msnp.array([self._b(i, self.fftz) for i in range(self.fftc)]) + + self.b = bx.reshape(-1, 1, 1) * by.reshape(1, -1, 1) * bz.reshape(1, 1, -1) + self.rfft3d = RFFT3D(self.fftx, self.ffty, self.fftz, self.fftc, inverse=False) + self.irfft3d = RFFT3D(self.fftx, self.ffty, self.fftz, self.fftc, inverse=True) + + def calculate_direct_energy(self, + qi: Tensor, + qj: Tensor, + dis: Tensor, + inv_dis: Tensor, + mask: Tensor = None): + """Calculate the direct energy term.""" + # (B,A,N) = (B,A,1) * (B,A,N) + qiqj = qi*qj + energy = qiqj * inv_dis * (self.erfc(self.alpha * dis)) + + if mask is None: + mask = dis < self.cutoff + else: + mask = F.logical_and(mask, dis < self.cutoff) + + energy = msnp.where(mask, energy, 0.0) + + # (B,A) + energy = F.reduce_sum(energy, -1) + # (B,1) + energy = func.keepdim_sum(energy, 1) * 0.5 + + return energy + + def calculate_self_energy(self, qi: Tensor, pbc_box: Tensor): + """Calculate the direct energy term.""" + # (B,A,1) = (B,A,1) * (B,A,1) + qiqi = qi * qi + + # (B,1) + qiqi_sum = F.reduce_sum(qiqi, 1) + qi_sum = F.reduce_sum(qi, 1) + + energy = -self.alpha / msnp.sqrt(msnp.pi) * qiqi_sum + energy -= qi_sum * 0.5 * msnp.pi / (self.alpha * self.alpha * F.reduce_prod(pbc_box, 1)) + return energy + + def calculate_exclude_energy(self, coordinate: Tensor, qi: Tensor, pbc_box: Tensor): + """Calculate the excluded correction energy term.""" + if self.exclude_index is not None: + # (B,b) + dis = self.get_exclude_distance(coordinate, pbc_box) * self.input_unit_scale + # (B,A) <- (B,A,1): + qi = F.reshape(qi, (qi.shape[0], -1)) + # (B,b,2) <- (B,A): + qi = gather_values(qi, self.exclude_pairs) + # (B,b) <- (B,b,2): + qiqj = F.reduce_prod(qi, -1) + energy = -qiqj * F.erf(self.alpha * dis) / dis + energy = func.keepdim_sum(energy, -1) + return energy + return msnp.zeros((qi.shape[0], 1), ms.float32) + + def calculate_reciprocal_energy(self, coordinate: Tensor, qi: Tensor, pbc_box: Tensor): + """Calculate the reciprocal energy term.""" + # the batch dimension in the following part is ignored due to the limitation of the operator FFT3D + # (B,A,3) <- (B,A,3) / (B,1,3) * (B,1,3): + pbc_box = pbc_box.reshape((-1, 1, 3)) + frac = coordinate / F.stop_gradient(pbc_box) % 1.0 * self.nfft + grid = self.cast(frac, ms.int32) + frac = frac - F.floor(frac) + # (B,A,64,3) <- (B,A,1,3) + (1,1,64,3): + neibor_grids = 
F.expand_dims(grid, 2) - self.base_grid + neibor_grids %= F.expand_dims(self.nfft, 2) + # (B,A,64,3) <- (B,A,1,3) * (1,1,64,3) + frac = F.expand_dims(frac, 2) + neibor_q = frac * frac * frac * self.ma + frac * frac * self.mb + frac * self.mc + self.md + # (B,A,64) <- (B,A,1) * reduce (B,A,64,3) + neibor_q = qi * F.reduce_prod(neibor_q, -1) + # (B,A,64,4) <- concat (B,A,64,1) (B,A,64,3): + neibor_grids = F.concat((self.batch_constant, neibor_grids), -1) + # (B, fftx, ffty, fftz): + q_matrix = msnp.zeros([1, self.fftx, self.ffty, self.fftz], ms.float32) + q_matrix = F.tensor_scatter_add(q_matrix, neibor_grids.reshape(-1, 4), neibor_q.reshape(-1)) + + mprefactor = msnp.pi * msnp.pi / -self.alpha / self.alpha + # (fftx, ffty, fftc) = (fftx, 1, 1) + (1, ffty, 1) + (1, 1, fftc) + msq = self.ffkx * self.ffkx / pbc_box[0][0][0] / pbc_box[0][0][0] + \ + self.ffky * self.ffky / pbc_box[0][0][1] / pbc_box[0][0][1] + \ + self.ffkz * self.ffkz / pbc_box[0][0][2] / pbc_box[0][0][2] + msq[0][0][0] = 1 + bc = 1.0 / msnp.pi / msq * msnp.exp(mprefactor * msq) / F.reduce_prod(pbc_box, -1)[0] + bc[0][0][0] = 0 + bc *= self.b + fq = self.rfft3d(q_matrix.reshape(self.fftx, self.ffty, self.fftz)) + bcfq = bc * fq + fbcfq = self.irfft3d(bcfq) + fbcfq = F.expand_dims(fbcfq, 0) + energy = q_matrix * fbcfq + energy = 0.5 * F.reduce_sum(energy, (-1, -2, -3)) + energy = energy.reshape(-1, 1) + + return energy + + def construct(self, + coordinate: Tensor, + qi: Tensor, + qj: Tensor, + dis: Tensor, + inv_dis: Tensor, + mask: Tensor = None, + pbc_box: Tensor = None): + """Calculate energy term.""" + + direct_energy = self.calculate_direct_energy(qi, qj, dis, inv_dis, mask) + self_energy = self.calculate_self_energy(qi, pbc_box) + exclude_energy = self.calculate_exclude_energy(coordinate, qi, pbc_box) + reciprocal_energy = self.calculate_reciprocal_energy(coordinate, qi, pbc_box) + return direct_energy + self_energy + exclude_energy + reciprocal_energy diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/dihedral.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/dihedral.py new file mode 100644 index 0000000000000000000000000000000000000000..c22199261179c8c706cbdbe66f07dfe8136a9e76 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/dihedral.py @@ -0,0 +1,203 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
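A quick note on the Coulomb file that just closed: `ParticleMeshEwaldCoulomb.construct` returns the sum of four Ewald contributions (direct + self + exclusion + reciprocal), and the real-space direct term can be verified in isolation without any of the FFT machinery. The sketch below is a NumPy stand-in for `calculate_direct_energy`, assuming the same neighbour-list layout; the function name `pme_direct_energy`, the charges, and the distances are illustrative.

```python
import numpy as np
from math import erfc

def pme_direct_energy(qi, qj, dis, inv_dis, alpha, cutoff, mask=None):
    """Real-space Ewald term: sum_ij q_i * q_j * erfc(alpha * r_ij) / r_ij.

    qi: (B, A, 1); qj, dis, inv_dis, mask: (B, A, N), mirroring the
    neighbour-list shapes used by calculate_direct_energy above.
    """
    erfc_v = np.vectorize(erfc)            # elementwise complementary error function
    energy = qi * qj * inv_dis * erfc_v(alpha * dis)
    in_range = dis < cutoff                # zero out pairs beyond the cutoff
    if mask is not None:
        in_range &= mask
    energy = np.where(in_range, energy, 0.0)
    # halve the sum because each pair appears twice in a full neighbour list
    return energy.sum(-1).sum(-1, keepdims=True) * 0.5

qi = np.array([[[1.0], [-1.0]]])           # (B=1, A=2, 1)
qj = np.array([[[-1.0], [1.0]]])           # each atom sees the other as neighbour
dis = np.array([[[0.3], [0.3]]])           # (B, A, N=1)
print(pme_direct_energy(qi, qj, dis, 1.0 / dis, alpha=0.275106, cutoff=1.0))
```

The erfc damping is what makes the real-space sum short-ranged; the slowly varying remainder is handled by the reciprocal-space term, which is why the reciprocal part above needs the B-spline weights and the FFT grid.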
+# ============================================================================
+"""Torsion energy"""
+
+import mindspore as ms
+from mindspore import Tensor
+from mindspore.ops import functional as F
+from mindspore import Parameter
+
+from .energy import EnergyCell
+from ...colvar import AtomTorsions
+from ...function import functions as func
+from ...function.units import Units
+
+
+class DihedralEnergy(EnergyCell):
+    r"""
+    Energy term of dihedral (torsion) angles.
+
+    .. math::
+
+        E_{dihedral}(\omega) = \sum_n 1 / 2 \times V_n \times [1 + \cos(n \times \omega - \gamma_n)]
+
+    Args:
+        index (Tensor): Tensor of shape (B, d, 4) or (1, d, 4). Data type is int.
+            Atom index of dihedral angles.
+        force_constant (Tensor): Tensor of shape (B, d) or (1, d). Data type is float.
+            The force constants of the torsional barrier (V_n).
+        periodicity (Tensor): Tensor of shape (B, d) or (1, d). Data type is int.
+            The periodicity of the torsional barrier (n).
+        phase (Tensor): Tensor of shape (B, d) or (1, d). Data type is float.
+            The phase shift in the torsional function (\gamma_n).
+        parameters (dict): Force field parameters. Default: None
+        use_pbc (bool): Whether to use periodic boundary condition. Default: None
+        energy_unit (str): Energy unit. Default: 'kj/mol'
+        units (Units): Units of length and energy. Default: None
+
+    Returns:
+        energy (Tensor), Tensor of shape (B, 1). Data type is float.
+
+    Symbols:
+        B:  Batchsize, i.e. number of walkers in simulation.
+        d:  Number of dihedral angles.
+        D:  Dimension of the simulation system. Usually is 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+    def __init__(self,
+                 index: Tensor = None,
+                 force_constant: Tensor = None,
+                 periodicity: Tensor = None,
+                 phase: Tensor = None,
+                 parameters: dict = None,
+                 use_pbc: bool = None,
+                 energy_unit: str = 'kj/mol',
+                 units: Units = None,
+                 ):
+
+        super().__init__(
+            label='dihedral_energy',
+            output_dim=1,
+            use_pbc=use_pbc,
+            energy_unit=energy_unit,
+            units=units,
+        )
+
+        if parameters is not None:
+            energy_unit = parameters.get('energy_unit')
+            self.units.set_energy_unit(energy_unit)
+
+            index = parameters.get('index')
+            force_constant = parameters.get('force_constant')
+            periodicity = parameters.get('periodicity')
+            phase = parameters.get('phase')
+
+        # (1,d,4)
+        index = Tensor(index, ms.int32)
+        if index.shape[-1] != 4:
+            raise ValueError('The last dimension of index in DihedralEnergy must be 4 but got: ' +
+                             str(index.shape[-1]))
+        if index.ndim == 2:
+            index = F.expand_dims(index, 0)
+        if index.ndim != 3:
+            raise ValueError(
+                'The rank of index must be 2 or 3 but got shape: '+str(index.shape))
+        self.index = Parameter(index, name='dihedral_index', requires_grad=False)
+
+        # (1,d)
+        self.get_torsion = AtomTorsions(self.index, use_pbc=use_pbc)
+
+        # d
+        self.num_torsions = index.shape[-2]
+
+        # (1,d)
+        force_constant = Tensor(force_constant, ms.float32)
+        if force_constant.shape[-1] != self.num_torsions:
+            raise ValueError('The last shape of force_constant ('+str(force_constant.shape[-1]) +
+                             ') must be equal to num_torsions ('+str(self.num_torsions)+')!')
+        if force_constant.ndim == 1:
+            force_constant = F.expand_dims(force_constant, 0)
+        if force_constant.ndim > 2:
+            raise ValueError('The rank of force_constant cannot be larger than 2!')
+        self.force_constant = Parameter(force_constant, name='dihedral_force_constant')
+
+        periodicity = Tensor(periodicity, ms.int32)
+        if periodicity.shape[-1] != self.num_torsions:
+            raise ValueError('The last shape of periodicity
('+str(periodicity.shape[-1]) + + ') must be equal to num_torsions ('+str(self.num_torsions)+')!') + if periodicity.ndim == 1: + periodicity = F.expand_dims(periodicity, 0) + if periodicity.ndim > 2: + raise ValueError('The rank of periodicity cannot be larger than 2!') + self.periodicity = Parameter(periodicity, name='periodicity') + + phase = Tensor(phase, ms.float32) + if phase.shape[-1] != self.num_torsions: + raise ValueError('The last shape of phase ('+str(phase.shape[-1]) + + ') must be equal to num_torsions ('+str(self.num_torsions)+')!') + if phase.ndim == 1: + phase = F.expand_dims(phase, 0) + if phase.ndim > 2: + raise ValueError('The rank of phase cannot be larger than 2!') + self.dihedral_phase = Parameter(phase, name='phase') + + def set_pbc(self, use_pbc=None): + self.use_pbc = use_pbc + self.get_torsion.set_pbc(use_pbc) + return self + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + inv_neigh_dis: Tensor = None, + pbc_box: Tensor = None, + ): + r""" + Calculate energy term. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour index. + neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. + inv_neigh_dis (Tensor): Tensor of shape (B, A, N). Data type is float. + Reciprocal of distances. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + + """ + # (B,M) + phi = self.get_torsion(coordinate, pbc_box) + + # (B,M) = (1,M) * (B,M) + nphi = self.periodicity * phi + + # (B,M) + cosphi = F.cos(nphi - self.dihedral_phase) + 1 + + # (B,M) = (1,M) + (B,M) + energy = 0.5 * self.force_constant * cosphi + + # (B,1) <- (B,M) + energy = func.keepdim_sum(energy, -1) + + return energy diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/energy.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/energy.py new file mode 100644 index 0000000000000000000000000000000000000000..d1669ca52e597e533be0f684a41f32b4db5a326e --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/energy.py @@ -0,0 +1,287 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
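The dihedral term defined above can likewise be checked with a few lines of NumPy. In the real cell the torsion angle omega comes from the `AtomTorsions` colvar, so the sketch passes omega in directly; the helper name `dihedral_energy` and the 3-fold torsion values are illustrative assumptions.

```python
import numpy as np

def dihedral_energy(omega, v_n, n, gamma):
    """E_dihedral = sum_n 1/2 * V_n * [1 + cos(n * omega - gamma_n)].

    omega: (B, d) torsion angles in radians; v_n, n, gamma: (d,) barrier
    heights, periodicities and phase shifts.
    """
    energy = 0.5 * v_n * (1.0 + np.cos(n * omega - gamma))
    return energy.sum(-1, keepdims=True)   # (B, 1)

# A single 3-fold torsion (n=3, gamma=0): barrier tops at 0 and +-120 degrees,
# minima at +-60 and 180 degrees.
omega = np.deg2rad(np.array([[0.0], [60.0], [180.0]]))   # (B=3, d=1)
print(dihedral_energy(omega, v_n=np.array([10.0]),
                      n=np.array([3.0]), gamma=np.array([0.0])))
```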
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base energy cell""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor +from mindspore import ops +from mindspore.nn import Cell + +from ...function import functions as func +from ...function.units import Units + + +class EnergyCell(Cell): + r""" + Basic cell for energy term. + + Args: + label (str): Label (name) of energy. + output_dim (int): Output dimension. Default: 1 + length_unit (str): Length unit for position coordinates. Default: 'nm' + energy_unit (str): Energy unit. Default: 'kj/mol' + units (Units): Units of length and energy. Default: None + use_pbc (bool): Whether to use periodic boundary condition. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, + label: str, + output_dim: int = 1, + length_unit: str = 'nm', + energy_unit: str = 'kj/mol', + units: Units = None, + use_pbc: bool = None, + ): + + super().__init__() + + self.label = label + self.output_dim = func.get_integer(output_dim) + + self.use_pbc = use_pbc + + if units is None: + self.units = Units(length_unit, energy_unit) + else: + if not isinstance(units, Units): + raise TypeError( + 'The type of units must be "Unit" but get type: '+str(type(units))) + self.units = units + + self.gather_values = func.gather_values + self.gather_vectors = func.gather_vectors + + self.input_unit_scale = 1 + self.cutoff = None + self.identity = ops.Identity() + + def set_input_unit(self, units: Units): + """ + Set the length unit for the input coordinates. + + Args: + units (Units): Units of length and energy. Default: None. + """ + if units is None: + self.input_unit_scale = 1 + elif isinstance(units, Units): + self.input_unit_scale = Tensor( + self.units.convert_length_from(units), ms.float32) + else: + raise TypeError('Unsupported type: '+str(type(units))) + return self + + def set_cutoff(self, cutoff: float): + """ + Set cutoff distances. + + Args: + cutoff (float): Cutoff distance. Default: None. + """ + if cutoff is None: + self.cutoff = None + else: + self.cutoff = Tensor(cutoff, ms.float32) + return self + + def set_pbc(self, use_pbc: bool = None): + """ + Set whether to use periodic boundary condition. + + Args: + use_pbc (bool, optional): Whether to use periodic boundary condition. Default: None. + """ + self.use_pbc = use_pbc + return self + + def convert_energy_from(self, unit: str) -> float: + """ + Convert energy from outside unit to inside unit. + + Args: + unit (str): Units of length and energy. Examples: 'nm', 'kj/mol'. + + Returns: + float, energy from outside unit to inside unit. + """ + return self.units.convert_energy_from(unit) + + def convert_energy_to(self, unit: str) -> float: + """ + Convert energy from inside unit to outside unit. + + Args: + unit (str): Units of length and energy. Examples: 'nm', 'kj/mol'. + + Returns: + float, energy from inside unit to outside unit. 
+ """ + return self.units.convert_energy_to(unit) + + @property + def length_unit(self) -> float: + return self.units.length_unit + + @property + def energy_unit(self) -> float: + return self.units.energy_unit + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + inv_neigh_dis: Tensor = None, + pbc_box: Tensor = None + ): + r""" + Calculate energy term. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. Default: None + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour index. Default: None + neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool. + Position coorindates of neighbour atoms. Default: None + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. Default: None + inv_neigh_dis (Tensor): Tensor of shape (B, A, N). Data type is float. + Reciprocal of distances. Default: None + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + + """ + raise NotImplementedError + + +class NonbondEnergy(EnergyCell): + r""" + Basic cell for non-bonded energy term + + Args: + label (str): Label (name) of energy. + output_dim (int): Dimension of the output. Default: 1 + cutoff (float): cutoff distance. Default: None + length_unit (str): Length unit for position coordinates. Default: None + energy_unit (str): Energy unit. Default: None + use_pbc (bool): Whether to use periodic boundary condition. Default: None + units (Units): Units of length and energy. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, + label: str, + output_dim: int = 1, + cutoff: float = None, + length_unit: str = 'nm', + energy_unit: str = 'kj/mol', + use_pbc: bool = None, + units: Units = None, + ): + + super().__init__( + label=label, + output_dim=output_dim, + length_unit=length_unit, + energy_unit=energy_unit, + units=units, + use_pbc=use_pbc, + ) + + self.cutoff = None + if cutoff is not None: + self.cutoff = Tensor(cutoff, ms.float32) + + self.inverse_input_scale = 1 + + def set_input_unit(self, units: Units): + """ + Set the length unit for the input coordinates. + + Args: + units (Units): Units of length and energy. Default: None. + """ + super().set_input_unit(units) + self.inverse_input_scale = msnp.reciprocal(self.input_unit_scale) + return self + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + inv_neigh_dis: Tensor = None, + pbc_box: Tensor = None, + ): + r"""Calculate energy term + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour index. 
+            neighbour_coord (Tensor): Tensor of shape (B, A, N, D). Data type is float.
+                Position coordinates of neighbour atoms.
+            neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float.
+                Distances between neighbour atoms.
+            inv_neigh_dis (Tensor): Tensor of shape (B, A, N). Data type is float.
+                Reciprocal of distances.
+            pbc_box (Tensor): Tensor of shape (B, D). Data type is float.
+                Tensor of PBC box. Default: None
+
+        Returns:
+            energy (Tensor), Tensor of shape (B, 1). Data type is float.
+
+        Symbols:
+            B:  Batchsize, i.e. number of walkers in simulation.
+            A:  Number of atoms.
+            D:  Dimension of the simulation system. Usually is 3.
+
+        """
+
+        raise NotImplementedError
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/lj.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/lj.py
new file mode 100644
index 0000000000000000000000000000000000000000..eed64a166c37fcb2417e3bc4cc3dd46229757b0b
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/lj.py
@@ -0,0 +1,251 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+#                       Peking University &
+#                       Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Lennard-Jones potential"""
+
+import mindspore as ms
+import mindspore.numpy as msnp
+from mindspore import Tensor, Parameter
+from mindspore.ops import functional as F
+
+from .energy import NonbondEnergy
+from ...function import functions as func
+from ...function.functions import gather_values
+from ...function.units import Units
+
+
+class LennardJonesEnergy(NonbondEnergy):
+    r"""
+    Lennard-Jones potential
+
+    .. math::
+
+        E_{lj}(r_{ij}) = 4 \times \epsilon_{ij} \times [(\sigma_{ij} / r_{ij}) ^ {12} - (\sigma_{ij} / r_{ij}) ^ 6]
+
+        \epsilon_{ij} = \sqrt{\epsilon_i \times \epsilon_j}
+
+        \sigma_{ij} = 1 / 2 \times (\sigma_i + \sigma_j)
+
+    Args:
+        epsilon (Tensor): Tensor of shape (B, A). Data type is float.
+            Parameter \epsilon for LJ potential. Default: None
+        sigma (Tensor): Tensor of shape (B, A). Data type is float.
+            Parameter \sigma in LJ potential. Default: None
+        mean_c6 (Tensor): Tensor of shape (B, A). Data type is float.
+            Average dispersion (<C6>) of the system, used for the
+            long-range correction of the dispersion interaction. Default: 0
+        parameters (dict): Force field parameters. Default: None
+        cutoff (float): Cutoff distance. Default: None
+        use_pbc (bool): Whether to use periodic boundary condition. Default: None
+        length_unit (str): Length unit for position coordinates. Default: 'nm'
+        energy_unit (str): Energy unit. Default: 'kj/mol'
+        units (Units): Units of length and energy. Default: None
+
+    Returns:
+        energy (Tensor), Tensor of shape (B, 1). Data type is float.
+ + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + N: Maximum number of neighbour atoms. + D: Dimension of the simulation system. Usually is 3. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + epsilon: Tensor = None, + sigma: Tensor = None, + mean_c6: Tensor = 0, + parameters: dict = None, + cutoff: float = None, + use_pbc: bool = None, + length_unit: str = 'nm', + energy_unit: str = 'kj/mol', + units: Units = None, + ): + + super().__init__( + label='vdw_energy', + output_dim=1, + cutoff=cutoff, + use_pbc=use_pbc, + length_unit=length_unit, + energy_unit=energy_unit, + units=units, + ) + + if parameters is not None: + length_unit = parameters.get('length_unit') + energy_unit = parameters.get('energy_unit') + self.units.set_units(length_unit, energy_unit) + + epsilon = parameters.get('epsilon') + sigma = parameters.get('sigma') + mean_c6 = parameters.get('mean_c6') + + sigma = Tensor(sigma, ms.float32) + epsilon = Tensor(epsilon, ms.float32) + + if sigma.shape[-1] != epsilon.shape[-1]: + raise ValueError('the last dimension of sigma'+str(sigma.shape[-1]) + + 'must be equal to the last dimension of epsilon ('+str(epsilon.shape[-1])+')!') + + self.num_atoms = sigma.shape[-1] + + if sigma.ndim == 1: + sigma = F.expand_dims(sigma, 0) + if sigma.ndim > 2: + raise ValueError('The rank of sigma cannot be larger than 2!') + self.sigma = Parameter(sigma, name='sigma') + + if epsilon.ndim == 1: + epsilon = F.expand_dims(epsilon, 0) + if epsilon.ndim > 2: + raise ValueError('The rank of epsilon cannot be larger than 2!') + self.epsilon = Parameter(epsilon, name='epsilon') + + self.mean_c6 = None + if mean_c6 is not None: + mean_c6 = Tensor(mean_c6, ms.float32) + if mean_c6.ndim == 0: + mean_c6 = mean_c6.reshape(1, 1) + elif mean_c6.ndim == 1: + mean_c6 = F.expand_dims(mean_c6, 0) + elif mean_c6.ndim > 2: + raise ValueError('The rank of mean_c6 cannot be larger than 2!') + self.mean_c6 = Parameter(Tensor(mean_c6, ms.float32), name='average_dispersion', requires_grad=False) + + self.disp_corr = self._calc_disp_corr() + + def set_cutoff(self, cutoff: float): + """ + Set cutoff distance. + + Args: + cutoff (float): Cutoff distance. Default: None. + """ + super().set_cutoff(cutoff) + self.disp_corr = self._calc_disp_corr() + return self + + def _calc_disp_corr(self) -> Tensor: + """ + calculate the long range correct factor for dispersion + + Returns: + Tensor, the long range correct factor for dispersion. + """ + + if self.cutoff is None: + return 0 + return -2.0 / 3.0 * msnp.pi * self.num_atoms**2 / msnp.power(self.cutoff, 3) + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + inv_neigh_dis: Tensor = None, + pbc_box: Tensor = None, + ): + r""" + Calculate energy term + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour index. + neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. + inv_neigh_dis (Tensor): Tensor of shape (B, A, N). Data type is float. 
+ Reciprocal of distances. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + + """ + + inv_neigh_dis *= self.inverse_input_scale + + epsilon = self.identity(self.epsilon) + sigma = self.identity(self.sigma) + + # (B,A,1) + eps_i = F.expand_dims(epsilon, -1) + # (B,A,N) + eps_j = gather_values(epsilon, neighbour_index) + # (B,A,N) = (B,A,1) * (B,A,N) + eps_ij = F.sqrt(eps_i * eps_j) + + # (B,A,1) + sigma_i = F.expand_dims(sigma, -1) + # (B,A,N) + sigma_j = gather_values(sigma, neighbour_index) + # (B,A,N) = (B,A,1) * (B,A,N) + sigma_ij = (sigma_i + sigma_j) * 0.5 + + # \sigma_ij / r_ij + sigma_over_rij = sigma_ij * inv_neigh_dis + # (\sigma_ij / r_ij) ^ 6 + sigma_over_rij_6 = F.pows(sigma_over_rij, 6) + + # 4 * \epsilon * (\sigma_ij / r_ij) ^ 6 + ene_bcoeff = 4 * eps_ij * sigma_over_rij_6 + # 4 * \epsilon * (\sigma_ij / r_ij) ^ 12 + ene_acoeff = ene_bcoeff * sigma_over_rij_6 + + # (B,A,N) + energy = ene_acoeff - ene_bcoeff + + # (B,A) + energy = F.reduce_sum(energy, -1) + # (B,1) + energy = func.keepdim_sum(energy, -1) * 0.5 + + if self.cutoff is not None and pbc_box is not None: + # (B,1) <- (B,D) + volume = func.keepdim_prod(pbc_box, -1) + # E_corr = -2 / 3 * pi * N * \rho * * r_c^-3 + # = -2 / 3 * pi * N * (N / V) * * r_c^-3 + # = -2 / 3 * pi * N^2 * / V + # = k_corr * / V + ene_corr = self.disp_corr * self.mean_c6 * msnp.reciprocal(volume) + energy += ene_corr + + return energy diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/pairs.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/pairs.py new file mode 100644 index 0000000000000000000000000000000000000000..165b5c668a7b3b05f6134a5fd5a0ba47d733b818 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/pairs.py @@ -0,0 +1,303 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Non-bonded pairwise energy""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor +from mindspore import Parameter +from mindspore import ops +from mindspore.ops import functional as F + +from .energy import EnergyCell +from ...colvar import AtomDistances +from ...function.units import Units +from ...function.functions import get_integer, keepdim_sum + + +class NonbondPairwiseEnergy(EnergyCell): + r""" + Energy of non-bonded atom paris. + + .. 
diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/pairs.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/pairs.py
new file mode 100644
index 0000000000000000000000000000000000000000..165b5c668a7b3b05f6134a5fd5a0ba47d733b818
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/energy/pairs.py
@@ -0,0 +1,303 @@
+# Copyright 2021-2022 @ Shenzhen Bay Laboratory &
+#                       Peking University &
+#                       Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Non-bonded pairwise energy"""
+
+import mindspore as ms
+import mindspore.numpy as msnp
+from mindspore import Tensor
+from mindspore import Parameter
+from mindspore import ops
+from mindspore.ops import functional as F
+
+from .energy import EnergyCell
+from ...colvar import AtomDistances
+from ...function.units import Units
+from ...function.functions import get_integer, keepdim_sum
+
+
+class NonbondPairwiseEnergy(EnergyCell):
+    r"""
+    Energy of non-bonded atom pairs.
+
+    .. math::
+
+        E_{pairs}(r_{ij}) = A_{ij}^p \cdot E_r(r_{ij}) + B_{ij}^p \cdot E_{r6}(r_{ij}) +
+                            C_{ij}^p \cdot E_{r12}(r_{ij})
+                          = A_{ij}^p \cdot k_{coulomb} \cdot q_i \cdot q_j / r_{ij} -
+                            B_{ij}^p \cdot 4 \cdot \epsilon_{ij} \cdot (\sigma_{ij} / r_{ij}) ^ 6 +
+                            C_{ij}^p \cdot 4 \cdot \epsilon_{ij} \cdot (\sigma_{ij} / r_{ij}) ^ {12}
+
+    Args:
+        index (Tensor):      Tensor of shape (B, p, 2). Data type is int.
+                             Atom index of each non-bonded atom pair.
+        qiqj (Tensor):       Tensor of shape (B, p). Data type is float.
+                             Products of charges of non-bonded atom pairs.
+        epsilon_ij (Tensor): Tensor of shape (B, p). Data type is float.
+                             \epsilon of non-bonded atom pairs.
+        sigma_ij (Tensor):   Tensor of shape (B, p). Data type is float.
+                             \sigma of non-bonded atom pairs.
+        r_scale (Tensor):    Tensor of shape (1, p). Data type is float.
+                             Scaling constant for r^-1 terms (A^p) in non-bond interaction.
+        r6_scale (Tensor):   Tensor of shape (1, p). Data type is float.
+                             Scaling constant for r^-6 terms (B^p) in non-bond interaction.
+        r12_scale (Tensor):  Tensor of shape (1, p). Data type is float.
+                             Scaling constant for r^-12 terms (C^p) in non-bond interaction.
+        parameters (dict):   Force field parameters. Default: None.
+        cutoff (float):      Cutoff distance. Default: None.
+        use_pbc (bool, optional): Whether to use periodic boundary condition.
+                             If this is None, periodic boundary condition is not used.
+                             Default: None.
+        length_unit (str):   Length unit for position coordinates. Default: 'nm'.
+        energy_unit (str):   Energy unit. Default: 'kj/mol'.
+        units (Units):       Units of length and energy. Default: None.
+
+    Returns:
+        energy (Tensor), Tensor of shape (B, 1). Data type is float.
+
+    Symbols:
+        B:  Batchsize, i.e. number of walkers in simulation.
+        p:  Number of non-bonded atom pairs.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    def __init__(self,
+                 index: Tensor = None,
+                 qiqj: Tensor = None,
+                 epsilon_ij: Tensor = None,
+                 sigma_ij: Tensor = None,
+                 r_scale: Tensor = None,
+                 r6_scale: Tensor = None,
+                 r12_scale: Tensor = None,
+                 parameters: dict = None,
+                 cutoff: float = None,
+                 use_pbc: bool = None,
+                 length_unit: str = 'nm',
+                 energy_unit: str = 'kj/mol',
+                 units: Units = None,
+                 ):
+
+        super().__init__(
+            label='nb_pairs_energy',
+            output_dim=1,
+            use_pbc=use_pbc,
+            length_unit=length_unit,
+            energy_unit=energy_unit,
+            units=units,
+        )
+
+        if parameters is not None:
+            length_unit = parameters.get('length_unit')
+            energy_unit = parameters.get('energy_unit')
+            self.units.set_units(length_unit, energy_unit)
+
+            index = parameters.get('index')
+            qiqj = parameters.get('qiqj')
+            epsilon_ij = parameters.get('epsilon_ij')
+            sigma_ij = parameters.get('sigma_ij')
+            r_scale = parameters.get('r_scale')
+            r6_scale = parameters.get('r6_scale')
+            r12_scale = parameters.get('r12_scale')
+
+        # (1,p,2)
+        index = Tensor(index, ms.int32)
+        if index.shape[-1] != 2:
+            raise ValueError('The last dimension of index in NonbondPairwiseEnergy must be 2 but got: ' +
+                             str(index.shape[-1]))
+        if index.ndim == 2:
+            index = F.expand_dims(index, 0)
+        if index.ndim != 3:
+            raise ValueError('The rank of index must be 2 or 3 but got shape: '+str(index.shape))
+        self.index = Parameter(index, name='pairs_index', requires_grad=False)
+
+        self.num_pairs = index.shape[-2]
+
+        qiqj = Tensor(qiqj, ms.float32)
+        if qiqj.shape[-1] != self.num_pairs:
+            raise ValueError('The last dimension of qiqj ('+str(qiqj.shape[-1]) +
+                             ') must be equal to the number of non-bonded atom pairs ('+str(self.num_pairs)+')!')
+        if qiqj.ndim == 1:
+            qiqj = F.expand_dims(qiqj, 0)
+ if qiqj.ndim > 2: + raise ValueError('The rank of qiqj cannot be larger than 2!') + self.qiqj = Parameter(qiqj, name='qiqj', requires_grad=False) + + epsilon_ij = Tensor(epsilon_ij, ms.float32) + if epsilon_ij.shape[-1] != self.num_pairs: + raise ValueError('The last dimension of epsilon_ij ('+str(epsilon_ij.shape[-1]) + + ') must be equal to the number of non-bonded atom pairs('+str(self.num_pairs)+')!') + if epsilon_ij.ndim == 1: + epsilon_ij = F.expand_dims(epsilon_ij, 0) + if epsilon_ij.ndim > 2: + raise ValueError('The rank of epsilon_ij cannot be larger than 2!') + self.epsilon_ij = Parameter(epsilon_ij, name='epsilon_ij', requires_grad=False) + + sigma_ij = Tensor(sigma_ij, ms.float32) + if sigma_ij.shape[-1] != self.num_pairs: + raise ValueError('The last dimension of sigma_ij ('+str(sigma_ij.shape[-1]) + + ') must be equal to the number of non-bonded atom pairs('+str(self.num_pairs)+')!') + if sigma_ij.ndim == 1: + sigma_ij = F.expand_dims(sigma_ij, 0) + if sigma_ij.ndim > 2: + raise ValueError('The rank of sigma_ij cannot be larger than 2!') + self.sigma_ij = Parameter(sigma_ij, name='sigma_ij', requires_grad=False) + + r_scale = Tensor(r_scale, ms.float32) + if r_scale.ndim == 0: + r_scale = r_scale.reshape(1, 1) + elif r_scale.ndim == 1: + r_scale = F.expand_dims(r_scale, 0) + elif r_scale.ndim > 2: + raise ValueError('The rank of r_scale cannot be larger than 2!') + if r_scale.shape[-1] != self.num_pairs and r_scale.shape[-1] != 1: + raise ValueError('The last dimension of r_scale ('+str(r_scale.shape[-1]) + + ') must be equal to 1 or the number of non-bonded atom pairs('+str(self.num_pairs)+')!') + self.r_scale = Parameter(r_scale, name='r_scale_factor') + + r6_scale = Tensor(r6_scale, ms.float32) + if r6_scale.ndim == 0: + r6_scale = r6_scale.reshape(1, 1) + elif r6_scale.ndim == 1: + r6_scale = F.expand_dims(r6_scale, 0) + elif r6_scale.ndim > 2: + raise ValueError('The rank of r6_scale cannot be larger than 2!') + if r6_scale.shape[-1] != self.num_pairs and r6_scale.shape[-1] != 1: + raise ValueError('The last dimension of r6_scale ('+str(r6_scale.shape[-1]) + + ') must be equal to 1 or the number of non-bonded atom pairs('+str(self.num_pairs)+')!') + self.r6_scale = Parameter(r6_scale, name='r6_scale_factor') + + r12_scale = Tensor(r12_scale, ms.float32) + if r12_scale.ndim == 0: + r12_scale = r12_scale.reshape(1, 1) + elif r12_scale.ndim == 1: + r12_scale = F.expand_dims(r12_scale, 0) + elif r12_scale.ndim > 2: + raise ValueError('The rank of r12_scale cannot be larger than 2!') + if r12_scale.shape[-1] != self.num_pairs and r12_scale.shape[-1] != 1: + raise ValueError('The last dimension of r12_scale ('+str(r12_scale.shape[-1]) + + ') must be equal to 1 or the number of non-bonded atom pairs('+str(self.num_pairs)+')!') + self.r12_scale = Parameter(r12_scale, name='r12_scale_factor') + + self.cutoff = None + if cutoff is not None: + self.cutoff = get_integer(cutoff) + + self.get_pairs_distance = AtomDistances( + self.index, use_pbc=use_pbc, length_unit=self.units) + + self.coulomb_const = self.units.coulomb + + self.concat = ops.Concat(-1) + + def set_pbc(self, use_pbc=None): + """ + Set whether to use periodic boundary condition. + + Args: + use_pbc (bool, optional): Whether to use periodic boundary condition. + If this is None, that means do not use periodic boundary condition. + Default: None. + """ + self.use_pbc = use_pbc + self.get_pairs_distance.set_pbc(use_pbc) + return self + + def set_cutoff(self, cutoff: float): + """ + Set cutoff distance. 
+ + Args: + cutoff (float): Cutoff distance. Default: None. + """ + if cutoff is None: + self.cutoff = None + else: + self.cutoff = Tensor(cutoff, ms.float32) + return self + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + inv_neigh_dis: Tensor = None, + pbc_box: Tensor = None, + ): + r""" + Calculate energy term + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour index. + neighbour_coord (Tensor): Tensor of shape (B, A, N). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. + inv_neigh_dis (Tensor): Tensor of shape (B, A, N). Data type is float. + Reciprocal of distances. + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + energy (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + p: Number of non-bonded atom pairs. + D: Dimension of the simulation system. Usually is 3. + """ + + distance = self.get_pairs_distance(coordinate, pbc_box) * self.input_unit_scale + # (B,p) + inv_dis = msnp.reciprocal(distance) + + # (B,p) = (1,p) * (B,p) * (1,p) + # A * k * qi * qj / r + energy_r = self.coulomb_const * self.qiqj * inv_dis * self.r_scale + + # \sigma_ij / r_ij + sigma_over_rij = self.sigma_ij * inv_dis + # (\sigma_ij / r_ij) ^ 6 + sigma_over_rij_6 = F.pows(sigma_over_rij, 6) + + ene_r6 = 4 * self.epsilon_ij * sigma_over_rij_6 + # -B * 4 * \epsilon * (\sigma_ij / r_ij) ^ 6 + energy_r6 = -ene_r6 * self.r6_scale + # C * 4 * \epsilon * (\sigma_ij / r_ij) ^ 12 + energy_r12 = ene_r6 * sigma_over_rij_6 * self.r12_scale + + # (B,1) <- (B,p) + energy_r = keepdim_sum(energy_r, -1) + energy_r6 = keepdim_sum(energy_r6, -1) + energy_r12 = keepdim_sum(energy_r12, -1) + + # (B, 1) + energy = energy_r + energy_r6 + energy_r12 + + return energy diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/forcefield.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/forcefield.py new file mode 100644 index 0000000000000000000000000000000000000000..85f473bcd204f3cdab2b0ad4b85869feced53b8c --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/forcefield.py @@ -0,0 +1,421 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Force filed""" +import os +import copy +from typing import Union +import numpy as np +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor +from mindspore import ops +from mindspore.nn import CellList + +from .energy import EnergyCell, BondEnergy, AngleEnergy, DihedralEnergy, NonbondPairwiseEnergy +from .energy import CoulombEnergy, LennardJonesEnergy +from .potential import PotentialCell +from ..data.parameters import ForceFieldParameters +from ..data.forcefield import get_forcefield +from ..system import Molecule +from ..function.units import Units + + +THIS_PATH = os.path.abspath(__file__) +BUILTIN_FF_PATH = THIS_PATH.replace('potential/forcefield.py', 'data/forcefield/') + + +class ForceFieldBase(PotentialCell): + r""" + Basic cell for force filed. + + Args: + energy (Union[EnergyCell, list]): Energy terms. The type of energy parameter can be list or EnergyCell. + Default: None. + cutoff (float): Cutoff distance. Default: None. + exclude_index (Tensor): Tensor of shape (B, A, Ex). Data type is int. + The indexes of atoms that should be excluded from neighbour list. + Default: None. + length_unit (str): Length unit for position coordinate. Default: None. + energy_unit (str): Energy unit. Default: None. + units (Units): Units of length and energy. Default: None. + use_pbc (bool, optional): Whether to use periodic boundary condition. + If this is "None", that means do not use periodic boundary condition. + Default: None. + + Returns: + potential (Tensor), Tensor of shape (B, 1). Data type is float. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + energy: Union[EnergyCell, list] = None, + cutoff: float = None, + exclude_index: Tensor = None, + length_unit: str = None, + energy_unit: str = None, + units: Units = None, + use_pbc: bool = None, + ): + + super().__init__( + cutoff=cutoff, + exclude_index=exclude_index, + length_unit=length_unit, + energy_unit=energy_unit, + units=units, + use_pbc=use_pbc, + ) + + self.num_energy = 0 + self.energy_cell = self.set_energy_cell(energy) + + self.energy_scale = 1 + + self.output_unit_scale = self.set_unit_scale() + + self.concat = ops.Concat(-1) + + def set_energy_scale(self, scale: Tensor): + """ + Set energy scale. + + Args: + scale (Tensor): Tensor of shape(B, 1). The scale parameter is used to set energy scale. + """ + scale = Tensor(scale, ms.float32) + if scale.ndim != 1 and scale.ndim != 0: + raise ValueError('The rank of energy scale must be 0 or 1.') + if scale.shape[-1] != self.output_dim and scale.shape[-1] != 1: + raise ValueError('The dimension of energy scale must be equal to the dimension of energy ' + + str(self.output_dim)+' or 1, but got: '+str(scale.shape[-1])) + self.energy_scale = scale + return self + + def set_energy_cell(self, energy: EnergyCell) -> CellList: + """ + Set energy. + + Args: + energy (Union[EnergyCell, list]): Energy terms. The type of energy parameter can be list or EnergyCell. + Default: None. + + Returns: + CellList. 
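+
+        Examples:
+            Sketch only (``bond_energy`` and ``angle_energy`` stand for any
+            constructed ``EnergyCell`` instances; not a verified doctest):
+            both forms are accepted, a single cell or a list of cells:
+
+            >>> # ff.set_energy_cell(bond_energy)
+            >>> # ff.set_energy_cell([bond_energy, angle_energy])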
+ """ + if energy is None: + return None + if isinstance(energy, EnergyCell): + self.num_energy = 1 + energy = CellList([energy]) + elif isinstance(energy, list): + self.num_energy = len(energy) + energy = CellList(energy) + else: + raise TypeError( + 'The type of energy must be EnergyCell or list but got: '+str(type(energy))) + + self.output_dim = 0 + if energy is not None: + for i in range(self.num_energy): + self.output_dim += energy[i].output_dim + return energy + + def set_unit_scale(self) -> Tensor: + """ + set unit scale. + + Returns: + Tensor, output unit scale. + """ + if self.energy_cell is None: + return 1 + output_unit_scale = () + for i in range(self.num_energy): + self.energy_cell[i].set_input_unit(self.units) + dim = self.energy_cell[i].output_dim + scale = np.ones((dim,), np.float32) * \ + self.energy_cell[i].convert_energy_to(self.units) + output_unit_scale += (scale,) + output_unit_scale = np.concatenate(output_unit_scale, axis=-1) + return Tensor(output_unit_scale, ms.float32) + + def set_units(self, length_unit: str = None, energy_unit: str = None, units: Units = None): + """ + Set units. + + Args: + length_unit (str): Length unit for position coordinate. Default: None. + energy_unit (str): Energy unit. Default: None. + units (Units): Units of length and energy. Default: None. + """ + if units is not None: + self.units.set_units(units=units) + else: + if length_unit is not None: + self.units.set_length_unit(length_unit) + if energy_unit is not None: + self.units.set_energy_unit(energy_unit) + + self.output_unit_scale = self.set_unit_scale() + + return self + + def set_pbc(self, use_pbc: bool = None): + """ + Set whether to use periodic boundary condition. + + Args: + use_pbc (bool, optional): Whether to use periodic boundary condition. + If this is "None", that means do not use periodic boundary condition. + Default: None. + """ + for i in range(self.num_energy): + self.energy_cell[i].set_pbc(use_pbc) + return self + + def set_cutoff(self, cutoff: Tensor = None): + """ + Set cutoff distance. + + Args: + cutoff (Tensor): Cutoff distance. Default: None. + """ + self.cutoff = None + if cutoff is not None: + self.cutoff = Tensor(cutoff, ms.float32) + for i in range(self.num_energy): + self.energy_cell[i].set_cutoff(self.cutoff) + return self + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + pbc_box: Tensor = None + ): + r""" + Calculate potential energy. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. Default: None + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour atoms. Default: None + neighbour_coord (Tensor): Tensor of shape (B, A, N, D). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distance (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. Default: None + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + potential (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + N: Maximum number of neighbour atoms. + D: Dimension of the simulation system. Usually is 3. 
+ """ + + inv_neigh_dis = 0 + inv_neigh_dis = msnp.reciprocal(neighbour_distance) + if neighbour_mask is not None: + inv_neigh_dis = msnp.where(neighbour_mask, inv_neigh_dis, 0) + + potential = () + for i in range(self.num_energy): + ene = self.energy_cell[i]( + coordinate=coordinate, + neighbour_index=neighbour_index, + neighbour_mask=neighbour_mask, + neighbour_coord=neighbour_coord, + neighbour_distance=neighbour_distance, + inv_neigh_dis=inv_neigh_dis, + pbc_box=pbc_box + ) + potential += (ene,) + + potential = self.concat(potential) * self.energy_scale * self.output_unit_scale + + return potential + + +class ForceField(ForceFieldBase): + r""" + Potential of classical force field. + + Args: + system (Molecule): Simulation system. + parameters (Union[dict, str]): Force field parameters. + cutoff (float): Cutoff distance. Default: None. + length_unit (str): Length unit for position coordinate. Default: None. + energy_unit (str): Energy unit. Default: None. + units (Units): Units of length and energy. Default: None. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + system: Molecule, + parameters: Union[dict, str], + cutoff: float = None, + length_unit: str = None, + energy_unit: str = None, + units: Units = None, + ): + + super().__init__( + cutoff=cutoff, + exclude_index=None, + length_unit=length_unit, + energy_unit=energy_unit, + units=units, + ) + + use_pbc = system.use_pbc + + # Generate Forcefield Parameters + parameters, template = get_forcefield(parameters) + for residue in system.residue: + residue.build_atom_type(template.get(residue.name)) + residue.build_atom_charge(template.get(residue.name)) + + system.build_system() + + ff_params = ForceFieldParameters( + system.atom_type, copy.deepcopy(parameters), atom_names=system.atom_name, + atom_charges=self.identity(system.atom_charge).asnumpy()) + + if isinstance(system.bond, np.ndarray): + system_params = ff_params(system.bond) + if isinstance(system.bond, Tensor): + system_params = ff_params(system.bond.asnumpy()) + + energy = [] + + # Bond energy + if system_params.bond_params is not None: + bond_index = system_params.bond_params['bond_index'] + bond_force_constant = system_params.bond_params['force_constant'] + bond_length = system_params.bond_params['bond_length'] + + bond_params: dict = parameters.get('bond_energy') + length_unit = bond_params.get('length_unit') + energy_unit = bond_params.get('energy_unit') + bond_energy = BondEnergy(bond_index, force_constant=bond_force_constant, + bond_length=bond_length, use_pbc=use_pbc, + length_unit=length_unit, energy_unit=energy_unit) + energy.append(bond_energy) + + # Angle energy + if system_params.angle_params is not None: + angle_index = system_params.angle_params['angle_index'] + angle_force_constant = system_params.angle_params['force_constant'] + bond_angle = system_params.angle_params['bond_angle'] + + angle_params: dict = parameters.get('angle_energy') + energy_unit = angle_params.get('energy_unit') + angle_energy = AngleEnergy(angle_index, force_constant=angle_force_constant, + bond_angle=bond_angle, use_pbc=use_pbc, energy_unit=energy_unit) + energy.append(angle_energy) + + # Dihedral energy + if system_params.dihedral_params is not None: + dihedral_index = Tensor(system_params.dihedral_params['dihedral_index'][None, :], ms.int32) + dihe_force_constant = Tensor(system_params.dihedral_params['force_constant'][None, :], ms.float32) + periodicity = Tensor(system_params.dihedral_params['periodicity'][None, :], ms.int32) + phase = 
Tensor(system_params.dihedral_params['phase'][None, :], ms.float32) + + # improper Parameters + improper_index = Tensor(system_params.improper_params['improper_index'][None, :], ms.int32) + + # Appending dihedral parameters and improper dihedral parameters. + dihedral_index = msnp.append(dihedral_index, improper_index, axis=1) + dihe_force_constant = msnp.append(dihe_force_constant, Tensor( + system_params.improper_params['force_constant'][None, :], ms.float32), axis=1) + periodicity = msnp.append(periodicity, Tensor( + system_params.improper_params['periodicity'][None, :], ms.int32), axis=1) + phase = msnp.append(phase, Tensor( + system_params.improper_params['phase'][None, :], ms.float32), axis=1) + + dihedral_params: dict = parameters.get('dihedral_energy') + energy_unit = dihedral_params.get('energy_unit') + dihedral_energy = DihedralEnergy(dihedral_index, force_constant=dihe_force_constant, + periodicity=periodicity, phase=phase, use_pbc=use_pbc, + energy_unit=energy_unit) + energy.append(dihedral_energy) + + # Electronic energy + if system.atom_charge is not None: + coulomb_params: dict = parameters.get('coulomb_energy') + length_unit = coulomb_params.get('length_unit') + energy_unit = coulomb_params.get('energy_unit') + ele_energy = CoulombEnergy(atom_charge=system.atom_charge, use_pbc=use_pbc, + length_unit=length_unit, energy_unit=energy_unit) + energy.append(ele_energy) + + # VDW energy + epsilon = None + sigma = None + if system_params.vdw_param is not None: + epsilon = system_params.vdw_param['epsilon'] + sigma = system_params.vdw_param['sigma'] + mean_c6 = system_params.vdw_param['mean_c6'] + + vdw_params: dict = parameters.get('vdw_energy') + length_unit = vdw_params.get('length_unit') + energy_unit = vdw_params.get('energy_unit') + vdw_energy = LennardJonesEnergy(epsilon=epsilon, sigma=sigma, mean_c6=mean_c6, use_pbc=use_pbc, + length_unit=length_unit, energy_unit=energy_unit) + energy.append(vdw_energy) + + # Non-bonded pairwise energy + if system_params.pair_params is not None and system_params.pair_params is not None: + pair_index = Tensor(ff_params.pair_index[None, :], ms.int32) + qiqj = system_params.pair_params['qiqj'] + epsilon_ij = system_params.pair_params['epsilon_ij'] + sigma_ij = system_params.pair_params['sigma_ij'] + r_scale = system_params.pair_params['r_scale'] + r6_scale = system_params.pair_params['r6_scale'] + r12_scale = system_params.pair_params['r12_scale'] + + pair_params: dict = parameters.get('nb_pair_energy') + length_unit = pair_params.get('length_unit') + energy_unit = pair_params.get('energy_unit') + pair_energy = NonbondPairwiseEnergy(pair_index, qiqj=qiqj, epsilon_ij=epsilon_ij, sigma_ij=sigma_ij, + r_scale=r_scale, r6_scale=r6_scale, r12_scale=r12_scale, + length_unit=length_unit, energy_unit=energy_unit, use_pbc=use_pbc) + energy.append(pair_energy) + + # Exclude Parameters + self._exclude_index = Tensor(system_params.excludes[None, :], ms.int32) + self.energy_cell = self.set_energy_cell(energy) + self.output_unit_scale = self.set_unit_scale() diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/potential/potential.py b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/potential.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1f00772e396e2381cc41bcbadc26d495867f15 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/potential/potential.py @@ -0,0 +1,202 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code 
is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Potential""" + +import mindspore as ms +from mindspore import Tensor, Parameter +from mindspore import ops +from mindspore.nn import Cell +from mindspore.ops import functional as F + +from ..function.functions import gather_vectors +from ..function.operations import GetDistance, GetVector +from ..function.units import Units, global_units + + +class PotentialCell(Cell): + r""" + Basic cell for potential energy. + + Args: + cutoff (float): Cutoff distance. Default: None. + exclude_index (Tensor): Tensor of shape (B, A, Ex). Data type is int. + Index of the atoms should be excluded from non-bond interaction. + Default: None. + length_unit (str): Length unit for position coordinates. Default: None. + energy_unit (str): Energy unit. Default: None. + units (Units): Units of length and energy. Default: None. + use_pbc (bool, optional): Whether to use periodic boundary condition. + If this is None, that means do not use periodic boundary condition. + Default: None. + + Returns: + potential (Tensor), Tensor of shape (B, 1). Data type is float. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + cutoff: float = None, + exclude_index: Tensor = None, + length_unit: str = None, + energy_unit: str = None, + units: Units = None, + use_pbc: bool = None, + ): + + super().__init__() + + if units is None: + if length_unit is None and energy_unit is None: + self.units = global_units + else: + self.units = Units(length_unit, energy_unit) + else: + if not isinstance(units, Units): + raise TypeError('The type of units must be "Unit" but get type: '+str(type(units))) + self.units = units + + self.output_dim = 1 + + self.cutoff = None + if cutoff is not None: + self.cutoff = Tensor(cutoff, ms.float32) + + self.use_pbc = use_pbc + self._exclude_index = self._check_exclude_index(exclude_index) + + self.get_vector = GetVector(use_pbc) + self.get_distance = GetDistance(use_pbc) + self.gather_atoms = gather_vectors + + self.identity = ops.Identity() + + @property + def exclude_index(self) -> Tensor: + """ + exclude index. + + Returns: + Tensor, exclude index. 
+ """ + if self._exclude_index is None: + return None + return self.identity(self._exclude_index) + + def _check_exclude_index(self, exclude_index: Tensor): + """check excluded index.""" + if exclude_index is None: + return None + exclude_index = Tensor(exclude_index, ms.int32) + if exclude_index.ndim == 2: + exclude_index = F.expand_dims(exclude_index, 0) + if exclude_index.ndim != 3: + raise ValueError('The rank of exclude_index must be 2 or 3 but got: ' + + str(exclude_index.shape)) + # (B,A,Ex) + return Parameter(exclude_index, name='exclude_index', requires_grad=False) + + def set_exclude_index(self, exclude_index: Tensor): + """ + Set excluded index. + + Args: + exclude_index (Tensor): Tensor of shape (B, A, Ex). Data type is int. + Index of the atoms should be excluded from non-bond interaction. + Default: None. + """ + self._exclude_index = self._check_exclude_index(exclude_index) + return self + + @property + def length_unit(self): + return self.units.length_unit + + @property + def energy_unit(self): + return self.units.energy_unit + + def set_pbc(self, use_pbc: bool = None): + """ + Set PBC box. + + Args: + use_pbc (bool, optional): Whether to use periodic boundary condition. + If this is None, that means do not use periodic boundary condition. + Default: None. + """ + self.use_pbc = use_pbc + self.get_vector.set_pbc(use_pbc) + self.get_distance.set_pbc(use_pbc) + return self + + def set_cutoff(self, cutoff: Tensor = None): + """ + Set cutoff distance. + + Args: + cutoff (Tensor): Cutoff distance. Default: None + """ + self.cutoff = None + if cutoff is not None: + self.cutoff = Tensor(cutoff, ms.float32) + return self + + def construct(self, + coordinate: Tensor, + neighbour_index: Tensor = None, + neighbour_mask: Tensor = None, + neighbour_coord: Tensor = None, + neighbour_distance: Tensor = None, + pbc_box: Tensor = None + ): + r"""Calculate potential energy. + + Args: + coordinates (Tensor): Tensor of shape (B, A, D). Data type is float. + Position coordinate of atoms in system. + neighbour_index (Tensor): Tensor of shape (B, A, N). Data type is int. + Index of neighbour atoms. Default: None + neighbour_mask (Tensor): Tensor of shape (B, A, N). Data type is bool. + Mask for neighbour atoms. Default: None + neighbour_coord (Tensor): Tensor of shape (B, A, N, D). Data type is bool. + Position coorindates of neighbour atoms. + neighbour_distances (Tensor): Tensor of shape (B, A, N). Data type is float. + Distance between neighbours atoms. Default: None + pbc_box (Tensor): Tensor of shape (B, D). Data type is float. + Tensor of PBC box. Default: None + + Returns: + potential (Tensor), Tensor of shape (B, 1). Data type is float. + + Symbols: + B: Batchsize, i.e. number of walkers in simulation + A: Number of atoms. + N: Maximum number of neighbour atoms. + D: Dimension of the simulation system. Usually is 3. + + """ + #pylint: disable=invalid-name + + raise NotImplementedError diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf6a34011ab5702b1a3bb26a9ac1f40d7ba4d94 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. 
+# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Molecular system""" + +from .molecule import Molecule, Protein +from .residue import Residue, AminoAcid + +__all__ = ['Molecule', 'Protein', 'Residue', 'AminoAcid'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..491c0d9dc857b65d2a91ec1674667e9653080a2c --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Modeling""" + +from .add_missing_atoms import rotate_by_axis, add_h +from .hadder import AddHydrogen, ReadPdbByMindsponge +from .pdb_generator import gen_pdb +from .pdb_parser import read_pdb + +__all__ = ['rotate_by_axis', 'add_h', 'AddHydrogen', 'ReadPdbByMindsponge', 'gen_pdb', 'read_pdb'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/add_missing_atoms.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/add_missing_atoms.py new file mode 100644 index 0000000000000000000000000000000000000000..5eae39145f3ddfa004890e900b38bf3a279000f1 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/add_missing_atoms.py @@ -0,0 +1,127 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# + +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Add missing atoms module.
+"""
+import numpy as np
+
+
+def rotate_by_axis(axis, theta):
+    """Rotate an atom around a given axis by angle theta.
+    Args:
+        axis: The rotation axis.
+        theta: The rotation angle.
+    Returns:
+        The rotation matrix.
+    """
+    # Rodrigues' rotation formula; exact only for a unit axis vector.
+    vx, vy, vz = axis[0], axis[1], axis[2]
+    return np.array([[vx*vx*(1-np.cos(theta))+np.cos(theta),
+                      vx*vy*(1-np.cos(theta))-vz*np.sin(theta),
+                      vx*vz*(1-np.cos(theta))+vy*np.sin(theta)],
+                     [vx*vy*(1-np.cos(theta))+vz*np.sin(theta),
+                      vy*vy*(1-np.cos(theta))+np.cos(theta),
+                      vy*vz*(1-np.cos(theta))-vx*np.sin(theta)],
+                     [vx*vz*(1-np.cos(theta))-vy*np.sin(theta),
+                      vy*vz*(1-np.cos(theta))+vx*np.sin(theta),
+                      vz*vz*(1-np.cos(theta))+np.cos(theta)]])
+
+
+def add_h(crd, atype=None, i=None, j=None, k=None):
+    """Add hydrogen once.
+    Args:
+        crd: The coordinates of all atoms.
+        atype: Different types correspond to different addH algorithms.
+    Indexes:
+        c6: Add one hydrogen at atom i. j and k atoms are connected to atom i.
+    """
+    if atype is None:
+        raise ValueError('The type of AddH should not be None!')
+
+    if atype != 'h2o' and (i is None or j is None or k is None):
+        raise ValueError('3 atom indexes are needed.')
+
+    if atype == 'c6':
+        left_arrow = crd[j] - crd[i]
+        left_arrow /= np.linalg.norm(left_arrow)
+        right_arrow = crd[k] - crd[i]
+        right_arrow /= np.linalg.norm(right_arrow)
+        h_arrow = -1 * (left_arrow + right_arrow)
+        h_arrow /= np.linalg.norm(h_arrow)
+        return (h_arrow + crd[i])[None, :]
+
+    if atype == 'dihedral':
+        h_arrow = crd[j] - crd[k]
+        h_arrow /= np.linalg.norm(h_arrow)
+        return (h_arrow + crd[i])[None, :]
+
+    if atype == 'c2h4':
+        h_arrow_1 = crd[j] - crd[k]
+        h1 = (h_arrow_1/np.linalg.norm(h_arrow_1) + crd[i])[None, :]
+        middle_arrow = (crd[i] - crd[j])
+        middle_arrow /= np.linalg.norm(middle_arrow)
+        middle_arrow *= np.linalg.norm(h_arrow_1)
+        h_arrow_2 = -h_arrow_1 + middle_arrow
+        h2 = (h_arrow_2/np.linalg.norm(h_arrow_2) + crd[i])[None, :]
+        return np.append(h1, h2, axis=0)
+
+    if atype == 'ch3':
+        upper_arrow = crd[k] - crd[j]
+        upper_arrow /= np.linalg.norm(upper_arrow)
+        h1 = -upper_arrow + crd[i]
+        axes = crd[j] - crd[i]
+        rotate_matrix = rotate_by_axis(axes, 2 * np.pi / 3)
+        h2 = np.dot(rotate_matrix, h1-crd[i])
+        h2 /= np.linalg.norm(h2)
+        h2 += crd[i]
+        rotate_matrix = rotate_by_axis(axes, 4 * np.pi / 3)
+        h3 = np.dot(rotate_matrix, h1-crd[i])
+        h3 /= np.linalg.norm(h3)
+        h3 += crd[i]
+        h12 = np.append(h1[None, :], h2[None, :], axis=0)
+        return np.append(h12, h3[None, :], axis=0)
+
+    if atype == 'cc3':
+        h1 = crd[k]
+        upper_arrow = crd[j] - crd[i]
+        rotate_matrix = rotate_by_axis(upper_arrow, 2 * np.pi / 3)
+        h2 = np.dot(rotate_matrix, h1-crd[i])
+        h2 /= np.linalg.norm(h2)
+        return (h2 + crd[i])[None, :]
+
+    if atype == 'c2h2':
+        right_arrow = crd[k] - crd[i]
+        rotate_matrix = rotate_by_axis(right_arrow, 2 * np.pi / 3)
+        h1 = np.dot(rotate_matrix, crd[j]-crd[i])
+        h2 = np.dot(rotate_matrix, h1)
+        h1 /= np.linalg.norm(h1)
+        h1 = (h1 + crd[i])[None, :]
+        h2 /= np.linalg.norm(h2)
+        h2 = (h2 + crd[i])[None, :]
+        return np.append(h1, h2,
axis=0) + + if atype == 'h2o': + if i is None: + raise ValueError('The index of O atom should be given.') + + return None diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/hadder.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/hadder.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad165168c0c45837f048b123469c02f958185dd --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/hadder.py @@ -0,0 +1,730 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# + +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +H-Adder Module. +""" +import sys +import numpy as np +from .add_missing_atoms import add_h +from .pdb_generator import gen_pdb +from .pdb_parser import read_pdb + +hnames = {'ACE': {'CH3': ['H1', 'H2', 'H3']}, + 'ALA': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB1', 'HB2', 'HB3']}, + 'ARG': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], 'CD': ['HD2', 'HD3'], + 'NE': ['HE'], 'NH1': ['HH11', 'HH12'], 'NH2': ['HH21', 'HH22']}, + 'ASN': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'ND2': ['HD21', 'HD22']}, + 'ASP': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3']}, + 'CALA': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB1', 'HB2', 'HB3'], 'C': ['OXT']}, + 'CARG': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], 'CD': ['HD2', 'HD3'], + 'NE': ['HE'], 'NH1': ['HH11', 'HH12'], 'NH2': ['HH21', 'HH22'], 'C': ['OXT']}, + 'CASN': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'ND2': ['HD21', 'HD22'], 'C': ['OXT']}, + 'CASP': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'C': ['OXT']}, + 'CCYS': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'SG': ['HG'], 'C': ['OXT']}, + 'CGLN': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], 'NE2': ['HE21', 'HE22'], + 'C': ['OXT']}, + 'CGLU': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], 'C': ['OXT']}, + 'CGLY': {'N': ['H'], 'CA': ['HA2', 'HA3'], 'C': ['OXT']}, + 'CHID': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'ND1': ['HD1'], 'CE1': ['HE1'], 'CD2': ['HD2'], + 'C': ['OXT']}, + 'CHIS': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CE1': ['HE1'], 'NE2': ['HE2'], 'CD2': ['HD2'], + 'C': ['OXT']}, + 'CILE': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB'], 'CG2': ['HG21', 'HG22', 'HG23'], 'CG1': ['HG12', 'HG13'], + 'CD1': ['HD11', 'HD12', 'HD13'], 'C': ['OXT']}, + 'CLEU': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG'], 'CD1': ['HD11', 'HD12', 'HD13'], + 'CD2': ['HD21', 'HD22', 'HD23'], 'C': ['OXT']}, + 'CLYS': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], 'CD': ['HD2', 'HD3'], + 
'CE': ['HE2', 'HE3'], 'NZ': ['HZ1', 'HZ2', 'HZ3'], 'C': ['OXT']}, + 'CMET': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], 'CE': ['HE1', 'HE2', 'HE3'], + 'C': ['OXT']}, + 'CPHE': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CD1': ['HD1'], 'CE1': ['HE1'], 'CZ': ['HZ'], + 'CE2': ['HE2'], 'CD2': ['HD2'], 'C': ['OXT']}, + 'CPRO': {'CD': ['HD2', 'HD3'], 'CG': ['HG2', 'HG3'], 'CB': ['HB2', 'HB3'], 'CA': ['HA'], 'C': ['OXT']}, + 'CSER': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'OG': ['HG'], 'C': ['OXT']}, + 'CTHR': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB'], 'CG2': ['HG21', 'HG22', 'HG23'], 'OG1': ['HG1'], + 'C': ['OXT']}, + 'CTRP': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CD1': ['HD1'], 'NE1': ['HE1'], 'CZ2': ['HZ2'], + 'CH2': ['HH2'], 'CZ3': ['HZ3'], 'CE3': ['HE3'], 'C': ['OXT']}, + 'CTYR': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CD1': ['HD1'], 'CE1': ['HE1'], 'OH': ['HH'], + 'CE2': ['HE2'], 'CD2': ['HD2'], 'C': ['OXT']}, + 'CVAL': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB'], 'CG1': ['HG11', 'HG12', 'HG13'], + 'CG2': ['HG21', 'HG22', 'HG23'], 'C': ['OXT']}, + 'CYS': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'SG': ['HG']}, + 'GLN': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], 'NE2': ['HE21', 'HE22']}, + 'GLU': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3']}, + 'GLY': {'N': ['H'], 'CA': ['HA2', 'HA3']}, + 'HID': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'ND1': ['HD1'], 'CE1': ['HE1'], 'CD2': ['HD2']}, + 'HIS': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CE1': ['HE1'], 'NE2': ['HE2'], 'CD2': ['HD2']}, + 'ILE': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB'], 'CG2': ['HG21', 'HG22', 'HG23'], 'CG1': ['HG12', 'HG13'], + 'CD1': ['HD11', 'HD12', 'HD13']}, + 'LEU': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG'], 'CD1': ['HD11', 'HD12', 'HD13'], + 'CD2': ['HD21', 'HD22', 'HD23']}, + 'LYS': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], 'CD': ['HD2', 'HD3'], + 'CE': ['HE2', 'HE3'], 'NZ': ['HZ1', 'HZ2', 'HZ3']}, + 'MET': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], 'CE': ['HE1', 'HE2', 'HE3']}, + 'NALA': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB1', 'HB2', 'HB3']}, + 'NARG': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], + 'CD': ['HD2', 'HD3'], 'NE': ['HE'], 'NH1': ['HH11', 'HH12'], 'NH2': ['HH21', 'HH22']}, + 'NASN': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'ND2': ['HD21', 'HD22']}, + 'NASP': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3']}, + 'NCYS': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'SG': ['HG']}, + 'NGLN': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], + 'NE2': ['HE21', 'HE22']}, + 'NGLU': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3']}, + 'NGLY': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA2', 'HA3']}, + 'NHID': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'ND1': ['HD1'], 'CE1': ['HE1'], + 'CD2': ['HD2']}, + 'NHIS': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CE1': ['HE1'], 'NE2': ['HE2'], + 'CD2': ['HD2']}, + 'NILE': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB'], 'CG2': ['HG21', 'HG22', 'HG23'], + 'CG1': ['HG12', 'HG13'], 'CD1': ['HD11', 'HD12', 'HD13']}, + 'NLEU': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG'], + 'CD1': ['HD11', 'HD12', 'HD13'], 'CD2': ['HD21', 'HD22', 'HD23']}, + 'NLYS': 
{'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], + 'CD': ['HD2', 'HD3'], 'CE': ['HE2', 'HE3'], 'NZ': ['HZ1', 'HZ2', 'HZ3']}, + 'NME': {'N': ['H'], 'CH3': ['HH31', 'HH32', 'HH33']}, + 'NMET': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CG': ['HG2', 'HG3'], + 'CE': ['HE1', 'HE2', 'HE3']}, + 'NPHE': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CD1': ['HD1'], 'CE1': ['HE1'], + 'CZ': ['HZ'], 'CE2': ['HE2'], 'CD2': ['HD2']}, + 'NPRO': {'N': ['H2', 'H3'], 'CD': ['HD2', 'HD3'], 'CG': ['HG2', 'HG3'], 'CB': ['HB2', 'HB3'], 'CA': ['HA']}, + 'NSER': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'OG': ['HG']}, + 'NTHR': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB'], 'CG2': ['HG21', 'HG22', 'HG23'], + 'OG1': ['HG1']}, + 'NTRP': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CD1': ['HD1'], 'NE1': ['HE1'], + 'CZ2': ['HZ2'], 'CH2': ['HH2'], 'CZ3': ['HZ3'], 'CE3': ['HE3']}, + 'NTYR': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CD1': ['HD1'], 'CE1': ['HE1'], + 'OH': ['HH'], 'CE2': ['HE2'], 'CD2': ['HD2']}, + 'NVAL': {'N': ['H1', 'H2', 'H3'], 'CA': ['HA'], 'CB': ['HB'], 'CG1': ['HG11', 'HG12', 'HG13'], + 'CG2': ['HG21', 'HG22', 'HG23']}, + 'PHE': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CD1': ['HD1'], 'CE1': ['HE1'], 'CZ': ['HZ'], + 'CE2': ['HE2'], 'CD2': ['HD2']}, + 'PRO': {'CD': ['HD2', 'HD3'], 'CG': ['HG2', 'HG3'], 'CB': ['HB2', 'HB3'], 'CA': ['HA']}, + 'SER': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'OG': ['HG']}, + 'THR': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB'], 'CG2': ['HG21', 'HG22', 'HG23'], 'OG1': ['HG1']}, + 'TRP': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CD1': ['HD1'], 'NE1': ['HE1'], 'CZ2': ['HZ2'], + 'CH2': ['HH2'], 'CZ3': ['HZ3'], 'CE3': ['HE3']}, + 'TYR': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB2', 'HB3'], 'CD1': ['HD1'], 'CE1': ['HE1'], 'OH': ['HH'], + 'CE2': ['HE2'], 'CD2': ['HD2']}, + 'VAL': {'N': ['H'], 'CA': ['HA'], 'CB': ['HB'], 'CG1': ['HG11', 'HG12', 'HG13'], + 'CG2': ['HG21', 'HG22', 'HG23']}, + } + +hbond_type = { + 'ACE': { + 'CH3': np.array(['ch3', 'C', 'O']) + }, + 'ALA': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'CB']), + 'CB': np.array(['ch3', 'CA', 'C']) + }, + 'ARG': { + 'N': np.array(['dihedral', 'CA', 'CB']), + 'CA': np.array(['cc3', 'CB', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'CD': np.array(['c2h2', 'CG', 'NE']), + 'NE': np.array(['c6', 'CD', 'CZ']), + 'NH1': np.array([['dihedral', 'CZ', 'NH2'], + ['dihedral', 'CZ', 'NE']]), + 'NH2': np.array([['dihedral', 'CZ', 'NH1'], + ['dihedral', 'CZ', 'NE']]) + }, + 'ASN': { + 'ND2': np.array(['c2h4', 'CG', 'OD1']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'CB']) + }, + 'ASP': { + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'CB']) + }, + 'CALA': { + 'N': np.array(['dihedral', 'CA', 'CB']), + 'CA': np.array(['cc3', 'CB', 'C']), + 'CB': np.array(['ch3', 'CA', 'C']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CARG': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'CD': np.array(['c2h2', 'CG', 'NE']), + 'NE': np.array(['dihedral', 'CZ', 'NH1']), + 'NH1': np.array([['dihedral', 'CZ', 'NH2'], + ['dihedral', 'CZ', 'NE']]), + 'NH2': np.array([['dihedral', 
'CZ', 'NH1'], + ['dihedral', 'CZ', 'NE']]), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CASN': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'ND2': np.array(['c2h4', 'CG', 'OD1']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CASP': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CCYS': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'SG']), + 'SG': np.array(['dihedral', 'CB', 'CA']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CGLN': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'NE2': np.array(['c2h4', 'CD', 'OE1']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CGLU': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CGLY': { + 'CA': np.array(['c2h2', 'N', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CHID': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD2': np.array(['c6', 'CG', 'NE2']), + 'ND1': np.array(['c6', 'CG', 'CE1']), + 'CE1': np.array(['c6', 'ND1', 'NE2']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CHIS': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD2': np.array(['c6', 'CG', 'NE2']), + 'NE2': np.array(['c6', 'CD2', 'CE1']), + 'CE1': np.array(['c6', 'ND1', 'NE2']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CILE': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['cc3', 'CG1', 'CA']), + 'CG2': np.array(['ch3', 'CB', 'CA']), + 'CG1': np.array(['c2h2', 'CB', 'CD1']), + 'CD1': np.array(['ch3', 'CG1', 'CB']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CLEU': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['cc3', 'CD2', 'CB']), + 'CD2': np.array(['ch3', 'CG', 'CD1']), + 'CD1': np.array(['ch3', 'CG', 'CB']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CLYS': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'CD': np.array(['c2h2', 'CG', 'CE']), + 'CE': np.array(['c2h2', 'CD', 'NZ']), + 'NZ': np.array(['ch3', 'CE', 'CD']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CMET': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'SD']), + 'CE': np.array(['ch3', 'SD', 'CG']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CPHE': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD1': np.array(['c6', 'CG', 'CE1']), + 'CD2': np.array(['c6', 'CG', 'CE2']), + 'CE1': np.array(['c6', 'CD1', 'CZ']), + 'CE2': np.array(['c6', 'CD2', 'CZ']), + 'CZ': np.array(['c6', 'CE2', 'CE1']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CPRO': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'CB': 
np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'CD': np.array(['c2h2', 'CG', 'N']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CSER': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'OG']), + 'OG': np.array(['dihedral', 'CB', 'CA']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CTHR': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['cc3', 'CG2', 'OG1']), + 'OG1': np.array(['dihedral', 'CB', 'CA']), + 'CG2': np.array(['ch3', 'CB', 'CA']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CTRP': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD1': np.array(['c6', 'CG', 'NE1']), + 'NE1': np.array(['c6', 'CD1', 'CE2']), + 'CZ2': np.array(['c6', 'CE2', 'CH2']), + 'CH2': np.array(['c6', 'CZ2', 'CZ3']), + 'CZ3': np.array(['c6', 'CH2', 'CE3']), + 'CE3': np.array(['c6', 'CD2', 'CZ3']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CTYR': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD1': np.array(['c6', 'CG', 'CE1']), + 'CE1': np.array(['c6', 'CD1', 'CZ']), + 'CD2': np.array(['c6', 'CG', 'CE2']), + 'CE2': np.array(['c6', 'CD2', 'CZ']), + 'OH': np.array(['dihedral', 'CZ', 'CE2']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CVAL': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['cc3', 'CG2', 'CA']), + 'CG1': np.array(['ch3', 'CB', 'CA']), + 'CG2': np.array(['ch3', 'CB', 'CA']), + 'C': np.array(['dihedral', 'CA', 'N']) + }, + 'CYS': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'SG']), + 'SG': np.array(['dihedral', 'CB', 'CA']) + }, + 'GLN': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'NE2': np.array(['c2h4', 'CD', 'OE1']) + }, + 'GLU': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']) + }, + 'GLY': { + 'CA': np.array(['c2h2', 'N', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + }, + 'HID': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD2': np.array(['c6', 'CG', 'NE2']), + 'ND1': np.array(['c6', 'CG', 'CE1']), + 'CE1': np.array(['c6', 'ND1', 'NE2']) + }, + 'HIS': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD2': np.array(['c6', 'CG', 'NE2']), + 'NE2': np.array(['c6', 'CD2', 'CE1']), + 'CE1': np.array(['c6', 'ND1', 'NE2']) + }, + 'ILE': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['cc3', 'CG1', 'CA']), + 'CG2': np.array(['ch3', 'CB', 'CA']), + 'CG1': np.array(['c2h2', 'CB', 'CD1']), + 'CD1': np.array(['ch3', 'CG1', 'CB']) + }, + 'LEU': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['cc3', 'CD2', 'CB']), + 'CD2': np.array(['ch3', 'CG', 'CD1']), + 'CD1': np.array(['ch3', 'CG', 'CB']) + }, + 'LYS': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': 
np.array(['c2h2', 'CB', 'CD']), + 'CD': np.array(['c2h2', 'CG', 'CE']), + 'CE': np.array(['c2h2', 'CD', 'NZ']), + 'NZ': np.array(['ch3', 'CE', 'CD']) + }, + 'MET': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'SD']), + 'CE': np.array(['ch3', 'SD', 'CG']) + }, + 'NALA': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'CB']), + 'CB': np.array(['ch3', 'CA', 'C']) + }, + 'NARG': { + 'N': np.array(['ch3', 'CA', 'C']), + 'CA': np.array(['cc3', 'CB', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'CD': np.array(['c2h2', 'CG', 'NE']), + 'NE': np.array(['dihedral', 'CZ', 'NH1']), + 'NH1': np.array([['dihedral', 'CZ', 'NH2'], + ['dihedral', 'CZ', 'NE']]), + 'NH2': np.array([['dihedral', 'CZ', 'NH1'], + ['dihedral', 'CZ', 'NE']]) + }, + 'NASN': { + 'ND2': np.array(['c2h4', 'CG', 'OD1']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']) + }, + 'NASP': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']) + }, + 'NCYS': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'SG']), + 'SG': np.array(['dihedral', 'CB', 'CA']) + }, + 'NGLN': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'NE2': np.array(['c2h4', 'CD', 'OE1']) + }, + 'NGLU': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']) + }, + 'NGLY': { + 'CA': np.array(['c2h2', 'N', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + }, + 'NHID': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD2': np.array(['c6', 'CG', 'NE2']), + 'ND1': np.array(['c6', 'CG', 'CE1']), + 'CE1': np.array(['c6', 'ND1', 'NE2']) + }, + 'NHIS': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD2': np.array(['c6', 'CG', 'NE2']), + 'NE2': np.array(['c6', 'CD2', 'CE1']), + 'CE1': np.array(['c6', 'ND1', 'NE2']) + }, + 'NILE': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['cc3', 'CG1', 'CA']), + 'CG2': np.array(['ch3', 'CB', 'CA']), + 'CG1': np.array(['c2h2', 'CB', 'CD1']), + 'CD1': np.array(['ch3', 'CG1', 'CB']) + }, + 'NLEU': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['cc3', 'CD2', 'CB']), + 'CD2': np.array(['ch3', 'CG', 'CD1']), + 'CD1': np.array(['ch3', 'CG', 'CB']) + }, + 'NLYS': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'CD': np.array(['c2h2', 'CG', 'CE']), + 'CE': np.array(['c2h2', 'CD', 'NZ']), + 'NZ': np.array(['ch3', 'CE', 'CD']) + }, + 'NMET': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'SD']), + 'CE': np.array(['ch3', 'SD', 'CG']) + }, + 'NPHE': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD1': np.array(['c6', 'CG', 'CE1']), + 'CD2': 
np.array(['c6', 'CG', 'CE2']), + 'CE1': np.array(['c6', 'CD1', 'CZ']), + 'CE2': np.array(['c6', 'CD2', 'CZ']), + 'CZ': np.array(['c6', 'CE2', 'CE1']), + }, + 'NPRO': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'CD': np.array(['c2h2', 'CG', 'N']), + 'N': np.array(['c2h2', 'CA', 'CD']) + }, + 'NSER': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'OG']), + 'OG': np.array(['dihedral', 'CB', 'CA']) + }, + 'NTHR': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['cc3', 'CG2', 'OG1']), + 'OG1': np.array(['dihedral', 'CB', 'CA']), + 'CG2': np.array(['ch3', 'CB', 'CA']) + }, + 'NTRP': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD1': np.array(['c6', 'CG', 'NE1']), + 'NE1': np.array(['c6', 'CD1', 'CE2']), + 'CZ2': np.array(['c6', 'CE2', 'CH2']), + 'CH2': np.array(['c6', 'CZ2', 'CZ3']), + 'CZ3': np.array(['c6', 'CH2', 'CE3']), + 'CE3': np.array(['c6', 'CD2', 'CZ3']) + }, + 'NTYR': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD1': np.array(['c6', 'CG', 'CE1']), + 'CE1': np.array(['c6', 'CD1', 'CZ']), + 'CD2': np.array(['c6', 'CG', 'CE2']), + 'CE2': np.array(['c6', 'CD2', 'CZ']), + 'OH': np.array(['dihedral', 'CZ', 'CE2']) + }, + 'NVAL': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['ch3', 'CA', 'C']), + 'CB': np.array(['cc3', 'CG2', 'CA']), + 'CG1': np.array(['ch3', 'CB', 'CA']), + 'CG2': np.array(['ch3', 'CB', 'CA']) + }, + 'PHE': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD1': np.array(['c6', 'CG', 'CE1']), + 'CD2': np.array(['c6', 'CG', 'CE2']), + 'CE1': np.array(['c6', 'CD1', 'CZ']), + 'CE2': np.array(['c6', 'CD2', 'CZ']), + 'CZ': np.array(['c6', 'CE2', 'CE1']), + }, + 'PRO': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CG': np.array(['c2h2', 'CB', 'CD']), + 'CD': np.array(['c2h2', 'CG', 'N']), + }, + 'SER': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'OG']), + 'OG': np.array(['dihedral', 'CB', 'CA']) + }, + 'THR': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['cc3', 'CG2', 'OG1']), + 'OG1': np.array(['dihedral', 'CB', 'CA']), + 'CG2': np.array(['ch3', 'CB', 'CA']) + }, + 'TRP': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD1': np.array(['c6', 'CG', 'NE1']), + 'NE1': np.array(['c6', 'CD1', 'CE2']), + 'CZ2': np.array(['c6', 'CE2', 'CH2']), + 'CH2': np.array(['c6', 'CZ2', 'CZ3']), + 'CZ3': np.array(['c6', 'CH2', 'CE3']), + 'CE3': np.array(['c6', 'CD2', 'CZ3']) + }, + 'TYR': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['c2h2', 'CA', 'CG']), + 'CD1': np.array(['c6', 'CG', 'CE1']), + 'CE1': np.array(['c6', 'CD1', 'CZ']), + 'CD2': np.array(['c6', 'CG', 'CE2']), + 'CE2': np.array(['c6', 'CD2', 'CZ']), + 'OH': np.array(['dihedral', 'CZ', 'CE2']) + }, + 'VAL': { + 'CA': np.array(['cc3', 'CB', 'C']), + 'N': np.array(['dihedral', 'CA', 'C']), + 'CB': np.array(['cc3', 'CG2', 'CA']), + 'CG1': np.array(['ch3', 'CB', 'CA']), + 'CG2': np.array(['ch3', 'CB', 'CA']) + }, + 'NME': None +} + +addhs = 
{'c6': 1, + 'dihedral': 1, + 'c2h4': 2, + 'ch3': 3, + 'cc3': 1, + 'c2h2': 2} + +sys.path.append('../') + + +def AddHydrogen(pdb_in, pdb_out): + """ The API function for adding Hydrogen. + Args: + pdb_in(str): The input pdb file name, absolute file path is suggested. + pdb_out(str): The output pdb file name, absolute file path is suggested. + """ + pdb_name = pdb_in + new_pdb_name = pdb_out + atom_names, res_names, _, crds, _, _, _, _, _, \ + _, _, _, _, _ = read_pdb(pdb_name, ignoreh=True) + + for i, res in enumerate(res_names): + res = 'N' * (i == 0) + res + res = 'C' * (i == (len(res_names) - 1)) + res + h_names = [] + crds[i] = np.array(crds[i]) + for atom in atom_names[i]: + if atom == 'C' and len(res) == 4 and res.startswith( + 'C') and np.isin(atom_names[i], 'OXT').sum() == 1: + continue + if atom in hbond_type[res].keys() and len( + hbond_type[res][atom].shape) == 1: + addh_type = hbond_type[res][atom][0] + for name in hnames[res][atom]: + h_names.append(name) + try: + m = np.where(np.array(atom_names[i]) == [atom])[0][0] + n = np.where( + np.array( + atom_names[i]) == hbond_type[res][atom][1])[0][0] + o = np.where( + np.array( + atom_names[i]) == hbond_type[res][atom][2])[0][0] + except IndexError as e: + raise ValueError( + 'Some heavy atoms are missing in given pdb file.') from e + new_crd = add_h(np.array(crds[i]), + atype=addh_type, + i=m, + j=n, + k=o) + crds[i] = np.append(crds[i], new_crd, axis=0) + elif atom in hbond_type[res].keys(): + for j, hbond in enumerate(hbond_type[res][atom]): + addh_type = hbond[0] + h_names.append(hnames[res][atom][j]) + try: + m = np.where(np.array(atom_names[i]) == [atom])[0][0] + n = np.where(np.array(atom_names[i]) == hbond[1])[0][0] + o = np.where(np.array(atom_names[i]) == hbond[2])[0][0] + except IndexError as e: + raise ValueError( + 'Some heavy atoms are missing in given pdb file.') from e + new_crd = add_h(np.array(crds[i]), + atype=addh_type, + i=m, + j=n, + k=o) + crds[i] = np.append(crds[i], new_crd, axis=0) + else: + continue + for name in h_names: + atom_names[i].append(name) + + new_crds = crds[0] + for crd in crds[1:]: + new_crds = np.append(new_crds, crd, axis=0) + + new_atom_names = np.array(atom_names[0]) + for name in atom_names[1:]: + new_atom_names = np.append(new_atom_names, name) + + new_res_names = [] + new_res_ids = [] + for i, crd in enumerate(crds): + for _ in range(crd.shape[0]): + new_res_names.append(res_names[i]) + new_res_ids.append(i + 1) + + gen_pdb(new_crds[None, :], new_atom_names, + new_res_names, new_res_ids, new_pdb_name) + print('1 H-Adding task complete.') + +def ReadPdbByMindsponge(pdb_name, addh): + if addh: + t_name = pdb_name.replace('.pdb', '_addH_by_mindsponge.pdb') + AddHydrogen(pdb_name, t_name) + return read_pdb(t_name) + + return read_pdb(pdb_name) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/pdb_generator.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/pdb_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1ad9443541fafc696e817594bf712989486c53 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/pdb_generator.py @@ -0,0 +1,70 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# + +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. 
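The geometry tables above drive the hydrogen reconstruction: for each heavy atom the first entry names a geometry class ('c6', 'ch3', 'c2h2', ...) and the remaining two name the anchor atoms, while `addhs` above gives the number of hydrogens each class adds. A quick standalone check (a sketch; the glycine entries are copied from the table above):

```python
# Standalone sketch: the number of hydrogens a residue needs is the sum of
# `addhs` counts over its heavy-atom geometry entries.
addhs = {'c6': 1, 'dihedral': 1, 'c2h4': 2, 'ch3': 3, 'cc3': 1, 'c2h2': 2}
gly = {'CA': ('c2h2', 'N', 'C'),      # 2 hydrogens on CA (HA2/HA3)
       'N': ('dihedral', 'CA', 'C')}  # 1 hydrogen on the backbone N

num_h = sum(addhs[entry[0]] for entry in gly.values())
print(num_h)  # 3, matching H, HA2 and HA3 of a mid-chain glycine
```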
+# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Module used to generate a pdb file via crd and res names. +""" + +import os + + +def gen_pdb(crd, atom_names, res_names, res_ids, pdb_name='temp.pdb'): + """Write protein crd information into pdb format files. + Args: + crd(numpy.float32): The coordinates of protein atoms. + atom_names(numpy.str_): The atom names differ from aminos. + res_names(numpy.str_): The residue names of amino names. + res_ids(numpy.int32): A unique mask each same residue. + pdb_name(str): The path to save the pdb file, absolute path is suggested. + """ + success = 1 + file = os.open(pdb_name, os.O_RDWR | os.O_CREAT) + pdb = os.fdopen(file, "w") + + pdb.write('MODEL 1\n') + for i, c in enumerate(crd[0]): + pdb.write('ATOM'.ljust(6)) + pdb.write('{}'.format(i + 1).rjust(5)) + if len(atom_names[i]) < 4: + pdb.write(' ') + pdb.write(atom_names[i].ljust(3)) + else: + pdb.write(' ') + pdb.write(atom_names[i].ljust(4)) + pdb.write(res_names[i].rjust(4)) + pdb.write('A'.rjust(2)) + pdb.write('{}'.format(res_ids[i]).rjust(4)) + pdb.write(' ') + pdb.write('{:.3f}'.format(c[0]).rjust(8)) + pdb.write('{:.3f}'.format(c[1]).rjust(8)) + pdb.write('{:.3f}'.format(c[2]).rjust(8)) + pdb.write('1.0'.rjust(6)) + pdb.write('0.0'.rjust(6)) + pdb.write('{}'.format(atom_names[i][0]).rjust(12)) + pdb.write('\n') + pdb.write('TER\n') + pdb.write('ENDMDL\n') + pdb.write('END\n') + + pdb.close() + return success diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/pdb_parser.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/pdb_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2aab88a19a5452858332a777803d4664415a17 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/modeling/pdb_parser.py @@ -0,0 +1,382 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# + +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Read information from a pdb format file. 
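A usage sketch for `gen_pdb` above (the two-atom fragment and the output path are illustrative, not from the patch). The coordinates carry a leading walker axis, since the writer indexes `crd[0]`:

```python
import numpy as np
# from mindsponge1.system.modeling.pdb_generator import gen_pdb  # path as in this patch

crd = np.array([[[0.000, 0.000, 0.000],
                 [1.458, 0.000, 0.000]]], np.float32)  # shape (1, A, 3)
gen_pdb(crd,
        atom_names=['N', 'CA'],
        res_names=['ALA', 'ALA'],
        res_ids=[1, 1],
        pdb_name='ala_fragment.pdb')  # hypothetical output file
```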
+""" +import numpy as np + +restypes = [ + 'A', 'R', 'N', 'D', 'C', + 'Q', 'E', 'G', 'H', 'I', + 'L', 'K', 'M', 'F', 'P', + 'S', 'T', 'W', 'Y', 'V' +] +resdict = {'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4, + 'GLN': 5, 'GLU': 6, 'GLY': 7, 'HIS': 8, 'ILE': 9, + 'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14, + 'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19, + 'CALA': 0, 'CARG': 1, 'CASN': 2, 'CASP': 3, 'CCYS': 4, + 'CGLN': 5, 'CGLU': 6, 'CGLY': 7, 'CHIS': 8, 'CILE': 9, + 'CLEU': 10, 'CLYS': 11, 'CMET': 12, 'CPHE': 13, 'CPRO': 14, + 'CSER': 15, 'CTHR': 16, 'CTRP': 17, 'CTYR': 18, 'CVAL': 19, + 'NALA': 0, 'NARG': 1, 'NASN': 2, 'NASP': 3, 'NCYS': 4, + 'NGLN': 5, 'NGLU': 6, 'NGLY': 7, 'NHIS': 8, 'NILE': 9, + 'NLEU': 10, 'NLYS': 11, 'NMET': 12, 'NPHE': 13, 'NPRO': 14, + 'NSER': 15, 'NTHR': 16, 'NTRP': 17, 'NTYR': 18, 'NVAL': 19, + 'CHIE': 8, 'HIE': 8, 'NHIE': 8, 'WAT': 22 + } + +atom_types = [ + 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD', + 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3', + 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2', + 'CZ3', 'NZ', 'OXT' +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''], + 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''], + 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], + 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''], + 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], + 'WAT': ['OW', '', '', '', '', '', '', '', '', '', '', '', '', ''], +} +restype_name_to_atom14_masks = { + 'ALA': [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + 'ARG': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'ASN': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + 'ASP': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + 'CYS': [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + 'GLN': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + 'GLU': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + 'GLY': [1, 1, 1, 1, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0], + 'HIS': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + 'HIE': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + 'ILE': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + 'LEU': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + 'LYS': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + 'MET': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + 'PHE': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'PRO': [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + 'SER': [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + 'THR': [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + 'TRP': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + 'TYR': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + 'VAL': [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + 'UNK': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + 'WAT': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} + +atom14_order_dict = {'ALA': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4}, + 'ARG': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'CD': 6, + 'NE': 7, + 'CZ': 8, + 'NH1': 9, + 'NH2': 10}, + 'ASN': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'OD1': 6, + 'ND2': 7}, + 'ASP': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'OD1': 6, + 'OD2': 7}, + 'CYS': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'SG': 5}, + 'GLN': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'CD': 6, + 'OE1': 7, + 'NE2': 8}, + 'GLU': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'CD': 6, + 'OE1': 7, + 'OE2': 8}, + 'GLY': {'N': 0, 'CA': 1, 'C': 2, 'O': 3}, + 'HIS': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'ND1': 6, + 'CD2': 7, + 'CE1': 8, + 'NE2': 9}, + 'HIE': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'ND1': 6, + 'CD2': 7, + 'CE1': 8, + 'NE2': 9}, + 'ILE': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG1': 5, + 'CG2': 6, + 'CD1': 7}, + 'LEU': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'CD1': 6, + 'CD2': 7}, + 'LYS': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'CD': 6, + 'CE': 7, + 'NZ': 8}, + 'MET': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'SD': 6, 'CE': 7}, + 'PHE': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'CD1': 6, + 'CD2': 7, + 'CE1': 8, + 'CE2': 9, + 'CZ': 10}, + 'PRO': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6}, + 'SER': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'OG': 5}, + 'THR': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'OG1': 5, 'CG2': 6}, + 'TRP': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'CD1': 6, + 'CD2': 7, + 'NE1': 8, + 'CE2': 9, + 'CE3': 10, + 'CZ2': 11, + 'CZ3': 12, + 'CH2': 13}, + 'TYR': {'N': 0, + 'CA': 1, + 'C': 2, + 'O': 3, + 'CB': 4, + 'CG': 5, + 'CD1': 6, + 'CD2': 7, + 'CE1': 8, + 'CE2': 9, + 'CZ': 10, + 'OH': 11}, + 'VAL': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG1': 5, 'CG2': 6}, + 'UNK': {}, + 'WAT': {'OW': 0}} + +atom14_to_atom37_dict = {'ALA': [0, 1, 2, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0], + 'ARG': [0, 1, 2, 4, 3, 5, 11, 23, 32, 29, 30, 0, 0, 0], + 'ASN': [0, 1, 2, 4, 3, 5, 16, 15, 0, 0, 0, 0, 0, 0], + 'ASP': [0, 1, 2, 4, 3, 5, 16, 17, 0, 0, 0, 0, 0, 0], + 'CYS': [0, 1, 2, 4, 3, 10, 0, 0, 0, 0, 0, 0, 0, 0], + 'GLN': [0, 1, 2, 4, 3, 5, 11, 26, 25, 0, 0, 0, 0, 0], + 'GLU': [0, 1, 2, 4, 3, 5, 11, 26, 27, 0, 0, 0, 0, 0], + 'GLY': [0, 1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + 'HIS': [0, 1, 2, 4, 3, 5, 14, 13, 20, 25, 0, 0, 0, 0], + 'HIE': [0, 1, 2, 4, 3, 5, 14, 13, 20, 25, 0, 0, 0, 0], + 'ILE': [0, 1, 2, 4, 3, 6, 7, 12, 0, 0, 0, 0, 0, 0], + 'LEU': [0, 
1, 2, 4, 3, 5, 12, 13, 0, 0, 0, 0, 0, 0], + 'LYS': [0, 1, 2, 4, 3, 5, 11, 19, 35, 0, 0, 0, 0, 0], + 'MET': [0, 1, 2, 4, 3, 5, 18, 19, 0, 0, 0, 0, 0, 0], + 'PHE': [0, 1, 2, 4, 3, 5, 12, 13, 20, 21, 32, 0, 0, 0], + 'PRO': [0, 1, 2, 4, 3, 5, 11, 0, 0, 0, 0, 0, 0, 0], + 'SER': [0, 1, 2, 4, 3, 8, 0, 0, 0, 0, 0, 0, 0, 0], + 'THR': [0, 1, 2, 4, 3, 9, 7, 0, 0, 0, 0, 0, 0, 0], + 'TRP': [0, 1, 2, 4, 3, 5, 12, 13, 24, 21, 22, 33, 34, 28], + 'TYR': [0, 1, 2, 4, 3, 5, 12, 13, 20, 21, 32, 31, 0, 0], + 'VAL': [0, 1, 2, 4, 3, 6, 7, 0, 0, 0, 0, 0, 0, 0], + 'UNK': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + 'WAT': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} + + +def read_pdb(pdb_name, ignoreh=False): + """Read a pdb file and return atom information with numpy array format. + Args: + pdb_name(str): The pdb file name, absolute path is suggested. + Returns: + atom_names(list): 1-dimension list contain all atom names in each residue. + res_names(list): 1-dimension list of all residue names. + res_ids(numpy.int32): Unique id for each residue names. + crds(list): The list format of coordinates. + res_pointer(numpy.int32): The pointer where the residue starts. + flatten_atoms(numpy.str_): The flatten atom names. + flatten_crds(numpy.float32): The numpy array format of coordinates. + init_res_names(list): The residue name information of each atom. + init_res_ids(list): The residue id of each atom. + """ + with open(pdb_name, 'r', encoding="utf-8") as pdb: + lines = pdb.readlines() + atom_names = [] + atom_group = [] + res_names = [] + res_ids = [] + init_res_names = [] + init_res_ids = [] + crds = [] + crd_group = [] + res_pointer = [] + flatten_atoms = [] + flatten_crds = [] + atom14_positions = [] + atom14_atom_exists = [] + residx_atom14_to_atom37 = [] + for index, line in enumerate(lines): + if 'END' in line or 'TER' in line: + atom_names.append(atom_group) + crds.append(crd_group) + atom14_positions.append(atom_pos) + residx_atom14_to_atom37.append(atom14_to_atom37_dict[res_name]) + break + if not line.startswith('ATOM'): + continue + atom_name = line[12:16].strip() + if ignoreh and atom_name.startswith('H'): + continue + res_name = line[17:20].strip() + res_id = int(line[22:26].strip()) + crd = [float(line[30:38]), + float(line[38:46]), + float(line[46:54])] + pointer = int(line[6:11].strip()) - 1 + flatten_atoms.append(atom_name) + flatten_crds.append(crd) + init_res_names.append(res_name) + init_res_ids.append(res_id) + if not res_ids: + res_ids.append(res_id) + res_names.append(res_name) + atom14_atom_exists.append(restype_name_to_atom14_masks[res_name]) + atom_group.append(atom_name) + crd_group.append(crd) + res_pointer.append(0) + atom_pos = np.zeros((14, 3)) + if not atom_name.startswith('H') and atom_name != 'OXT': + atom_pos[atom14_order_dict[res_name] + [atom_name]] = np.array(crd) + elif res_id != res_ids[-1]: + atom14_positions.append(atom_pos) + residx_atom14_to_atom37.append(atom14_to_atom37_dict[res_name]) + atom_pos = np.zeros((14, 3)) + if not atom_name.startswith('H') and atom_name != 'OXT': + atom_pos[atom14_order_dict[res_name] + [atom_name]] = np.array(crd) + atom_names.append(atom_group) + crds.append(crd_group) + atom_group = [] + crd_group = [] + res_ids.append(res_id) + res_names.append(res_name) + atom14_atom_exists.append(restype_name_to_atom14_masks[res_name]) + atom_group.append(atom_name) + crd_group.append(crd) + res_pointer.append(pointer) + else: + atom_group.append(atom_name) + crd_group.append(crd) + if not atom_name.startswith('H') and atom_name != 'OXT': + 
atom_pos[atom14_order_dict[res_name] + [atom_name]] = np.array(crd) + if index == len(lines) - 1: + atom_names.append(atom_group) + crds.append(crd_group) + atom14_positions.append(atom_pos) + residx_atom14_to_atom37.append(atom14_to_atom37_dict[res_name]) + + res_ids = np.array(res_ids, np.int32) + flatten_atoms = np.array(flatten_atoms, np.str_) + flatten_crds = np.array(flatten_crds, np.float32) + init_res_names = np.array(init_res_names) + init_res_ids = np.array(init_res_ids, np.int32) + res_pointer = np.array(res_pointer, np.int32) + # Violation loss parameters + residue_index = np.arange(res_pointer.shape[0]) + aatype = np.zeros_like(residue_index) + for i in range(res_pointer.shape[0]): + aatype[i] = resdict[res_names[i]] + atom14_atom_exists = np.array(atom14_atom_exists, np.float32) + + atom14_positions = np.array(atom14_positions, np.float32) + residx_atom14_to_atom37 = np.array(residx_atom14_to_atom37, np.float32) + + return atom_names, res_names, res_ids, crds, res_pointer, flatten_atoms, flatten_crds, init_res_names,\ + init_res_ids,\ + residue_index, aatype, atom14_positions, atom14_atom_exists, residx_atom14_to_atom37 diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/molecule/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/molecule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..347c0b9878a6e2137461d726ac3a88e7c55c6091 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/molecule/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Molecules +""" + +from .molecule import Molecule +from .protein import Protein + +__all__ = ['Molecule', 'Protein'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/molecule/molecule.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/molecule/molecule.py new file mode 100644 index 0000000000000000000000000000000000000000..491dc9ba9d975dfe9cc7a2847c982df8c5b6efec --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/molecule/molecule.py @@ -0,0 +1,875 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
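For reference, a sketch of consuming `read_pdb` above (the input file name is hypothetical); it returns the 14 values listed in its docstring, in order:

```python
# Unpack all 14 return values of read_pdb.
(atom_names, res_names, res_ids, crds, res_pointer,
 flatten_atoms, flatten_crds, init_res_names, init_res_ids,
 residue_index, aatype, atom14_positions, atom14_atom_exists,
 residx_atom14_to_atom37) = read_pdb('example.pdb', ignoreh=True)

# flatten_crds: (num_atoms, 3) float32; atom14_positions: (num_res, 14, 3);
# aatype: integer residue types looked up through `resdict` above.
print(atom14_positions.shape, aatype.shape)
```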
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Molecule +""" + +import copy +import itertools +from typing import Union, Tuple +import numpy as np +from numpy import ndarray +import mindspore as ms +from mindspore import Parameter +from mindspore import ops +from mindspore.ops import functional as F +from mindspore.nn import Cell +from mindspore.common import Tensor +from mindspore import numpy as msnp + +from ..residue import Residue +from ...data.template import get_molecule +from ...function import functions as func +from ...function.units import Units, global_units +from ...function.functions import get_ndarray + + +class Molecule(Cell): + r""" + Cell for molecular system. + + Args: + atoms (list): Atoms in system. Can be list of str or int. Default: None. + atom_name (list): Atom name. Can be ndarray or list of str. Default: None. + atom_type (list): Atom type. Can be ndarray or list of str. Default: None. + atom_mass (Tensor): Tensor of shape (B, A). Data type is float. + Atom mass. Default: None. + atom_charge (Tensor): Tensor of shape (B, A). Data type is float. + Atom charge. Default: None. + atomic_number (Tensor): Tensor of shape (B, A). Data type is float. + Atomic number. Default: None. + bond (Tensor): Tensor of shape (B, b, 2) or (1, b, 2). Data type is int. + Bond index. Default: None. + coordinate (Tensor): Tensor of shape (B, A, D) or (1, A, D). Data type is float. + Position coordinates of atoms. Default: None. + pbc_box (Tensor): Tensor of shape (B, D) or (1, D). Data type is float. + Box of periodic boundary condition. Default: None. + template (Union[dict, str]): Template of residue. + The key of the dict are base, template, the name of molecule and so on. + The value of the dict is file name. + Default: None. + residue (Union[Residue, list]): Residue parameter. Default: None. + length_unit (str): Length unit for position coordinates. Default: None. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + b: Number of bonds. + D: Dimension of the simulation system. Usually is 3. 
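A minimal construction sketch (assumptions: the module path as added by this patch, and that atom masses are derived from the atomic numbers, which is not shown in this excerpt):

```python
import numpy as np
# from mindsponge1.system.molecule import Molecule  # path as in this patch

# Integer `atoms` are routed to `atomic_number`; one walker, so the
# coordinate array has shape (1, A, 3).
water = Molecule(
    atoms=np.array([8, 1, 1]),
    coordinate=np.array([[[0.000, 0.000, 0.000],
                          [0.100, 0.000, 0.000],
                          [-0.033, 0.094, 0.000]]], np.float32),
)
coordinate, pbc_box = water()  # runs construct(); pbc_box is None without a box
```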
+ """ + + def __init__(self, + atoms: list = None, + atom_name: list = None, + atom_type: list = None, + atom_mass: Tensor = None, + atom_charge: Tensor = None, + atomic_number: Tensor = None, + bond: Tensor = None, + coordinate: Tensor = None, + pbc_box: Tensor = None, + template: Union[dict, str] = None, + residue: Union[Residue, list] = None, + length_unit: str = None, + ): + + super().__init__() + + if length_unit is None: + self.units = global_units + else: + self.units = Units(length_unit) + + if template is not None: + molecule, template = get_molecule(template) + residue: list = [] + for res in molecule.get('residue'): + residue.append(Residue(name=res, template=template)) + if coordinate is None: + coordinate = np.array(molecule.get('coordinate'), np.float32) + coordinate *= self.units.convert_length_from(molecule.get('length_unit')) + + self.num_residue = 1 + if residue is None or not residue: + if atoms is not None: + atoms = get_ndarray(atoms) + if np.issubdtype(atoms.dtype, np.integer): + if atomic_number is None: + atomic_number = atoms + elif np.issubdtype(atoms.dtype, np.character): + if atom_name is None: + atom_name = atoms + else: + raise TypeError( + 'The dtype of atoms must be integer of character!') + + if atom_name is not None or atomic_number is not None: + residue = Residue( + atom_name=atom_name, + atom_type=atom_type, + atom_mass=atom_mass, + atom_charge=atom_charge, + atomic_number=atomic_number, + bond=bond, + ) + + self.residue = None + self.num_residue = 0 + if residue is not None: + if isinstance(residue, list): + self.residue = residue + elif isinstance(residue, Residue): + self.residue = [residue] + else: + raise ValueError( + 'The type of residue must be Residue or list but got: '+str(type(residue))) + + # The number of multi_system of system + self.multi_system = 1 + # A: number of atoms + self.num_atoms = 0 + + # (B,A) + self.atom_name = None + self.atom_type = None + self.atom_mass = None + self.atom_mask = None + self.atomic_number = None + self.inv_mass = None + self.atom_charge = None + + # (B,R) + self.residue_mass = None + self.residue_name = None + self.res_natom_tensor = None + # (R) + self.residue_pointer = None + # (A) + self.atom_resid = None + self.image_index = None + + # (B,C,2) + self.bond = None + self.hydrogen_bond = None + # (B,C): bond length for constraint + self.bond_length = None + + # (B,A,D) + self.coordinate = None + # (B,D) + self.pbc_box = None + + self.dimension = None + self.num_walker = None + self.degrees_of_freedom = None + # (B,1) + self.system_mass = None + self.has_empty_atom = None + self.system_natom = None + + self.use_pbc = False + self.num_com = None + self.image = None + + self.build_system() + if self.residue is not None: + self.build_space(coordinate, pbc_box) + + @property + def length_unit(self): + return self.units.length_unit + + def _check_pbc_box(self, pbc_box: Tensor): + """check PBC box.""" + pbc_box = Tensor(pbc_box, ms.float32) + if pbc_box.ndim == 1: + pbc_box = F.expand_dims(pbc_box, 0) + if pbc_box.ndim != 2: + raise ValueError('The rank of pbc_box must be 1 or 2!') + if pbc_box.shape[-1] != self.dimension: + raise ValueError('The last dimension of "pbc_box" ('+str(pbc_box.shape[-1]) + + ') must be equal to the dimension of "coordinate" ('+str(self.dimension)+')!') + if pbc_box.shape[0] > 1 and pbc_box.shape[0] != self.num_walker: + raise ValueError('The first dimension of "pbc_box" ('+str(pbc_box.shape[0]) + + ') does not match the first dimension of "coordinate" ('+str(self.dimension)+')!') + 
return Parameter(pbc_box, name='pbc_box', requires_grad=True) + + def move(self, shift: Tensor = None): + """ + Move the coordinate of the system. + + Args: + shift (Tensor): Shift parameter. Default: None. + """ + if shift is not None: + self.update_coordinate(self.coordinate + Tensor(shift, ms.float32)) + return self + + def copy(self, shift: Tensor = None): + """ + Return a Molecule that copy the parameters of this molecule. + + Args: + shift (Tensor): Shift parameter. Default: None. + """ + coordinate = self.get_coordinate() + if shift is not None: + coordinate += Tensor(shift, ms.float32) + return Molecule( + residue=copy.deepcopy(self.residue), + coordinate=coordinate, + pbc_box=self.get_pbc_box(), + length_unit=self.length_unit, + ) + + def add_residue(self, residue: Residue, coordinate: Tensor = None): + """ + Add residue. + + Args: + residue (Union[Residue, list]): Residue parameter. + coordinate (Tensor): Tensor of shape (B, A, D) or (1, A, D). Data type is float. + Position coordinates of atoms. Default: None. + """ + if not isinstance(residue, list): + if isinstance(residue, Residue): + residue = [residue] + else: + raise TypeError('The type of residue must be Residue or list but got: ' + + str(type(residue))) + + self.residue.extend(copy.deepcopy(residue)) + self.build_system() + if coordinate is None: + natoms = 0 + for res in residue: + natoms += res.num_atoms + coordinate = msnp.ones((self.num_walker, natoms, self.dimension), ms.float32) + + coordinate = msnp.concatenate((self.coordinate, coordinate), axis=-2) + self.build_space(coordinate, self.pbc_box) + return self + + def append(self, system): + """ + Append the system. + + Args: + system (Molecule): System parameter. + """ + if not isinstance(system, Molecule): + raise TypeError('For add, the type of system must be "Molecule" but got: ' + + str(type(system))) + self.add_residue(system.residue, system.get_coordinate()) + return self + + def reduplicate(self, shift: Tensor): + """ + Duplicate the system to double of the origin size. + + Args: + shift (Tensor): Shift parameter. Default: Tensor. 
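A sketch of the displacement helpers defined above (`system` stands for any built `Molecule`):

```python
import mindspore as ms
from mindspore import Tensor

# assuming `system` is a Molecule as sketched earlier
shifted = system.copy(shift=Tensor([0.5, 0.0, 0.0], ms.float32))  # displaced deep copy
system.move(Tensor([0.1, 0.0, 0.0], ms.float32))                  # shift in place
system.reduplicate(Tensor([1.0, 0.0, 0.0], ms.float32))           # append a shifted duplicate
```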
+ """ + shift = Tensor(shift, ms.float32) + self.residue.extend(copy.deepcopy(self.residue)) + self.build_system() + coordinate = msnp.concatenate((self.coordinate, self.coordinate+shift), axis=-2) + self.build_space(coordinate, self.pbc_box) + return self + + def build_atom_type(self): + """build atom type.""" + atom_type = () + for i in range(self.num_residue): + atom_type += (self.residue[i].atom_type,) + self.atom_type = np.concatenate(atom_type, axis=-1) + return self + + def build_atom_charge(self): + """build atom charge.""" + charges = [] + for i in range(self.num_residue): + charges.append(self.residue[i].atom_charge is not None) + + if any(charges): + atom_charge = () + for i in range(self.num_residue): + if self.residue[i].atom_charge is None: + atom_charge += (msnp.zeros_like(self.residue[i].atom_mass),) + else: + atom_charge += (self.residue[i].atom_charge,) + self.atom_charge = msnp.concatenate(atom_charge, axis=-1) + return self + + def build_system(self): + """build the system by residues.""" + if self.residue is None: + self.residue = None + return self + + self.num_residue = len(self.residue) + multi_system = [] + charges = [] + for i in range(self.num_residue): + multi_system.append(self.residue[i].multi_system) + charges.append(self.residue[i].atom_charge is not None) + multi_system = list(set(multi_system)) + if len(multi_system) == 1: + self.multi_system = multi_system[0] + elif len(multi_system) == 2 and (multi_system[0] == 1 or multi_system[1] == 1): + self.multi_system = max(multi_system) + else: + raise ValueError( + 'The multi_system of residues cannot be broadcast: '+str(multi_system)) + + any_charge = any(charges) + + atom_name = () + atom_type = () + atom_mass = () + atom_mask = () + atom_charge = () + atomic_number = () + inv_mass = () + + atom_resid = () + image_index = () + + residue_mass = () + res_natom_tensor = () + + bond = () + head_atom = None + tail_atom = None + + pointer = 0 + residue_pointer = [] + residue_name = [] + + for i in range(self.num_residue): + if self.residue[i].multi_system != self.multi_system: + self.residue[i].broadcast_multiplicity(self.multi_system) + + self.residue[i].set_start_index(pointer) + residue_pointer.append(pointer) + residue_name.append(self.residue[i].name) + + # (A') + atom_resid += (msnp.full((self.residue[i].num_atoms,), i, ms.int32),) + image_index += (msnp.full((self.residue[i].num_atoms,), pointer, ms.int32),) + + # (B,A') + atom_name += (self.residue[i].atom_name,) + atom_type += (self.residue[i].atom_type,) + atom_mass += (self.residue[i].atom_mass,) + atom_mask += (self.residue[i].atom_mask,) + atomic_number += (self.residue[i].atomic_number,) + inv_mass += (self.residue[i].inv_mass,) + if any_charge: + if self.residue[i].atom_charge is None: + atom_charge += (msnp.zeros_like( + self.residue[i].atom_mass),) + else: + atom_charge += (self.residue[i].atom_charge,) + + # (B,1) + residue_mass += (self.residue[i].total_mass,) + res_natom_tensor += (self.residue[i].natom_tensor,) + + # (B,1) + head_atom = self.residue_head(i) + if head_atom is not None: + if tail_atom is None: + print('Warrning! The head_atom of residue '+str(i)+' is not None' + + ' but the tail_atom of residue '+str(i-1)+' is None. 
') + else: + # (B,1,2) + connect = msnp.concatenate( + (F.expand_dims(tail_atom, -2), F.expand_dims(head_atom, -2)), axis=-1) + bond += (connect,) + # (B,1,1) + tail_atom = self.residue_tail(i) + + # (B,C',2) + if self.residue[i].bond is not None: + bond += (self.residue[i].bond + pointer,) + + pointer += self.residue[i].num_atoms + + self.num_atoms = pointer + self.residue_pointer = Tensor(residue_pointer, ms.int32) + self.residue_name = np.array(residue_name, np.str_) + + # (B,A) + self.atom_name = np.concatenate(atom_name, axis=-1) + self.atom_type = np.concatenate(atom_type, axis=-1) + self.atom_mass = msnp.concatenate(atom_mass, axis=-1) + self.atom_mask = msnp.concatenate(atom_mask, axis=-1) + self.atomic_number = msnp.concatenate(atomic_number, axis=-1) + self.inv_mass = msnp.concatenate(inv_mass, axis=-1) + self.atom_charge = None + if any_charge: + self.atom_charge = msnp.concatenate(atom_charge, axis=-1) + + # (A) + self.atom_resid = msnp.concatenate(atom_resid) + self.image_index = msnp.concatenate(image_index) + + # (B,R) + self.residue_mass = msnp.concatenate(residue_mass, axis=-1) + self.res_natom_tensor = msnp.concatenate(res_natom_tensor, axis=-1) + + # (B,C,2) + self.bond = None + if bond: + self.bond = msnp.concatenate(bond, -2) + + return self + + def build_space(self, coordinate: Tensor, pbc_box: Tensor = None): + """ + Build coordinate and PBC box. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D) or (1, A, D). Data type is float. + Position coordinates of atoms. + pbc_box (Tensor): Tensor of shape (B, D) or (1, D). Data type is float. + Box of periodic boundary condition. Default: None. + """ + # (B,A,D) + if coordinate is None: + coordinate = np.random.uniform(0, self.units.length( + 1, 'nm'), size=(self.multi_system, self.num_atoms, 3)) + coordinate = Tensor(coordinate, ms.float32) + coordinate = self._check_coordianate(coordinate) + self.coordinate = Parameter(coordinate, name='coordinate') + self.dimension = self.coordinate.shape[-1] + self.num_walker = self.coordinate.shape[0] + + # (B,1) + self.system_mass = msnp.sum(self.atom_mass, -1, keepdims=True) + self.has_empty_atom = (not self.atom_mask.all()) + # (B,1) <- (B,A) + self.system_natom = msnp.sum(F.cast(self.atom_mask, ms.float32), -1, keepdims=True) + + self.keep_prod = ops.ReduceProd(keep_dims=True) + self.identity = ops.Identity() + + # (B,D) + if pbc_box is None: + self.pbc_box = None + self.use_pbc = False + self.num_com = self.dimension + self.image = None + else: + self.use_pbc = True + self.num_com = self.dimension + pbc_box = Tensor(pbc_box, ms.float32) + if pbc_box.ndim == 1: + pbc_box = F.expand_dims(pbc_box, 0) + if pbc_box.ndim != 2: + raise ValueError('The rank of pbc_box must be 1 or 2!') + if pbc_box.shape[-1] != self.dimension: + raise ValueError('The last dimension of "pbc_box" ('+str(pbc_box.shape[-1]) + + ') must be equal to the dimension of "coordinate" ('+str(self.dimension)+')!') + if pbc_box.shape[0] > 1 and pbc_box.shape[0] != self.num_walker: + raise ValueError('The first dimension of "pbc_box" ('+str(pbc_box.shape[0]) + + ') does not match the first dimension of "coordinate" ('+str(self.dimension)+')!') + self.pbc_box = Parameter(pbc_box, name='pbc_box') + + self.image = Parameter(msnp.zeros_like(self.coordinate, ms.int32), name='coordinate_image', + requires_grad=False) + self.update_image() + + self.degrees_of_freedom = self.dimension * self.num_atoms - self.num_com + return self + + def set_bond_length(self, bond_length: Tensor): + """ + Set bond length. 
+ + Args: + bond_length (Tensor): Length of bond. + """ + if self.bond is None: + raise ValueError('Cannot set bond_length because bond is None') + bond_length = Tensor(bond_length, ms.float32) + if bond_length.shape != self.bond.shape[:2]: + raise ValueError('The shape of bond_length '+str(bond_length.shape) + + ' does not match the shape of bond '+str(self.bond.shape)) + self.bond_length = bond_length + return self + + def residue_index(self, res_id: int) -> Tensor: + """ + Get index of residue. + + Args: + res_id (int): Residue ID parameter. + + Returns: + Tensor, the index of residue. + """ + return self.residue[res_id].system_index + + def residue_bond(self, res_id: int) -> Tensor: + """ + Get bond index of residue. + + Args: + res_id (int): Residue ID parameter. + + Returns: + Tensor, the bond index of residue. + """ + if self.residue[res_id].bond is None: + return None + return self.residue[res_id].bond + self.residue[res_id].start_index + + def residue_head(self, res_id: int) -> Tensor: + """ + Get head index of residue. + + Args: + res_id (int): Residue ID parameter. + + Returns: + Tensor, the head index of residue. + """ + if self.residue[res_id].head_atom is None: + return None + return self.residue[res_id].head_atom + self.residue[res_id].start_index + + def residue_tail(self, res_id: int) -> Tensor: + """ + Get tail index of residue. + + Args: + res_id (int): Residue ID parameter. + + Returns: + Tensor, the tail index of residue. + """ + if self.residue[res_id].tail_atom is None: + return None + return self.residue[res_id].tail_atom + self.residue[res_id].start_index + + def residue_coordinate(self, res_id: int) -> Tensor: + """ + Get residue coordinate. + + Args: + res_id (int): Residue ID parameter. + + Returns: + Tensor, the residue coordinate. + """ + return F.gather_d(self.coordinate, -2, self.residue[res_id].system_index) + + def get_volume(self) -> Tensor: + """ + get volume of system. + + Returns: + Tensor, volume of system. + """ + if self.pbc_box is None: + return None + return self.keep_prod(self.pbc_box, -1) + + def space_parameters(self) -> list: + """ + get the parameter of space (coordinates and pbc box). + + Returns: + list, a list of parameter of space. + """ + if self.pbc_box is None: + return [self.coordinate] + return [self.coordinate, self.pbc_box] + + def trainable_params(self, recurse=True) -> list: + """ + Get the trainable parameters (atom coordinates) of the system. + + Args: + recurse (bool, optional): Recurse parameter. Default: True. + + Returns: + list, a list of trainable_params. + """ + return list(filter(lambda x: x.name.split('.')[-1] == 'coordinate', self.get_parameters(expand=recurse))) + + def _check_coordianate(self, coordinate: Tensor) -> Tensor: + """ + check coordinate. + + Returns: + Tensor, a Tensor of coordinate. 
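A sketch of the per-residue accessors above (`system` again stands for a built `Molecule`):

```python
# assuming `system` is a Molecule with at least one residue and a pbc_box
idx = system.residue_index(0)    # system-wide atom indices of residue 0
bonds = system.residue_bond(0)   # residue bonds offset by start_index, or None
head = system.residue_head(1)    # connection atom towards residue 0, or None
volume = system.get_volume()     # (B, 1) product over pbc_box edges; None without a box
```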
+ """ + coordinate = Tensor(coordinate, ms.float32) + if coordinate.ndim == 2: + coordinate = F.expand_dims(coordinate, 0) + if coordinate.ndim != 3: + raise ValueError('The rank of "coordinate" must be 2 or 3!') + if coordinate.shape[-2] != self.num_atoms: + raise ValueError('The penultimate dimension of "coordinate" ('+str(coordinate.shape[-2]) + + ') must be equal to the number of atoms ('+str(self.num_atoms)+')!') + if self.multi_system > 1 and coordinate.shape[0] != self.multi_system: + raise ValueError('The first dimension of "coordinate" ('+str(coordinate.shape[0]) + + ') does not match the that of "atom_name" ('+str(self.multi_system)+')!') + return coordinate + + def update_coordinate(self, coordinate: Tensor, success: bool = True) -> bool: + """ + Update the parameter of coordinate. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D) or (1, A, D). Data type is float. + Position coordinates of atoms. + success (bool, optional): Success parameter. Default: True. + + Returns: + bool, whether update the parameter of coordinate. + """ + success = F.depend(success, F.assign(self.coordinate, coordinate)) + if self.pbc_box is not None: + success = self.update_image(success=success) + return success + + def set_coordianate(self, coordinate: Tensor): + """ + Set the value of coordinate. + + Args: + coordinate (Tensor): Tensor of shape (B, A, D) or (1, A, D). Data type is float. + Position coordinates of atoms. Default: None. + """ + coordinate = self._check_coordianate(coordinate) + if coordinate is not None and coordinate.shape == self.coordinate.shape: + self.update_coordinate(coordinate) + else: + self.coordinate = Parameter(coordinate, name='coordinate') + self.dimension = self.coordinate.shape[-1] + self.num_walker = self.coordinate.shape[0] + return self + + def update_pbc_box(self, pbc_box: Tensor, success: bool = True): + """ + Update PBC box. + + Args: + pbc_box (Tensor): Tensor of shape (B, D) or (1, D). Data type is float. + Box of periodic boundary condition. Default: None. + success (bool, optional): Success parameter. Default: True. + + Returns: + bool, whether update PBC box. + """ + success = F.depend(True, F.assign(self.pbc_box, pbc_box)) + if self.pbc_box is not None: + success = self.update_image(success=success) + return success + + def set_pbc_grad(self, grad_box: bool): + """ + Set whether to calculate the gradient of PBC box. + + Args: + grad_box (bool): Whether to calculate the gradient of PBC box. + """ + if self.pbc_box is not None: + self.pbc_box.requires_grad = grad_box + return self + + def set_pbc_box(self, pbc_box: Tensor = None): + """ + Set PBC box. + + Args: + pbc_box (Tensor): Tensor of shape (B, D) or (1, D). Data type is float. + Box of periodic boundary condition. Default: None. + """ + if pbc_box is None: + self.pbc_box = None + self.use_pbc = False + self.num_com = self.dimension + else: + self.use_pbc = True + self.num_com = self.dimension * 2 + if self.pbc_box is None: + self.pbc_box = self._check_pbc_box(pbc_box) + else: + if pbc_box.shape != self.pbc_box.shape: + raise ValueError('The shape of the new pbc_box '+str(pbc_box.shape) + + 'is not equal to the old one '+str(self.pbc_box)+'!') + self.update_pbc_box(pbc_box) + return self + + def repeat_box(self, lattices: list): + """ + Repeat the system according to the lattices of PBC box. + + Args: + lattices (list): Lattices parameter. 
+ """ + if self.pbc_box is None: + raise RuntimeError('repeat_box() cannot be used without pbc_box, ' + 'please use set_pbc_box() to set pbc_box first ' + 'before using this function.') + + if isinstance(lattices, Tensor): + lattices = lattices.asnumpy() + if isinstance(lattices, ndarray): + lattices = lattices.tolist() + if not isinstance(lattices, list): + raise TypeError('The type of lattices must be list, ndarry or Tensor but got: ' + + str(type(lattices))) + if len(lattices) != self.dimension: + raise ValueError('The number of lattics ('+str(len(lattices))+') must be equal to ' + 'the dimension of system ('+str(self.dimension)+')') + product_ = [] + for l in lattices: + if l <= 0: + raise ValueError('The number in lattices must larger than 0!') + product_.append(list(range(l))) + + shift_num = tuple(itertools.product(*product_))[1:] + if shift_num: + shift_box = Tensor(shift_num, ms.float32) * self.pbc_box + box = self.copy() + coord = box.get_coordinate() + coordinate = (coord,) + for shift in shift_box: + self.residue.extend(copy.deepcopy(box.residue)) + coordinate += (coord+shift,) + + self.build_system() + coordinate = msnp.concatenate(coordinate, axis=-2) + self.build_space(coordinate, self.pbc_box) + new_box = Tensor(lattices, ms.int32) * self.pbc_box + self.update_pbc_box(new_box) + + return self + + def coordinate_in_box(self, shift: float = 0) -> Tensor: + """ + Get the coordinate in a whole PBC box. + + Args: + shift (float): Shift parameter. Default: 0. + + Returns: + Tensor, the coordinate in a whole PBC box. + """ + coordinate = self.identity(self.coordinate) + pbc_box = self.identity(self.pbc_box) + return func.displace_in_box(coordinate, pbc_box, shift) + + def calc_image(self, shift: float = 0) -> Tensor: + """ + Calculate the image of coordinate. + + Args: + shift (float): Shift parameter. Default: 0. + + Returns: + Tensor, a Tensor of the image of coordinate. + """ + coordinate = self.identity(self.coordinate) + pbc_box = self.identity(self.pbc_box) + image = func.periodic_image(coordinate, pbc_box, shift) + if self.image_index is not None: + image = image[:, self.image_index, :] + return image + + def update_image(self, image: Tensor = None, success: bool = True) -> bool: + """ + Update the image of coordinate. + + Args: + image (Tensor): Image parameter. Default: None. + success (bool, optional): Success parameter. Default: True. + + Returns: + bool. + """ + if image is None: + image = self.calc_image() + return F.depend(success, F.assign(self.image, image)) + + def set_length_unit(self, unit): + """ + Set the length unit of system. + + Args: + unit (Units): Units of length and energy. + """ + scale = self.units.convert_length_to(unit) + coordinate = self.coordinate * scale + self.update_coordinate(coordinate) + if self.pbc_box is not None: + pbc_box = self.pbc_box * scale + self.update_pbc_box(pbc_box) + self.units.set_length_unit(unit) + return self + + def get_coordinate(self) -> Tensor: + """ + get Tensor of coordinate. + + Returns: + Tensor, a Tensor of coordinate. + """ + return self.identity(self.coordinate) + + def get_pbc_box(self) -> Tensor: + """ + get Tensor of PBC box. + + Returns: + Tensor, a Tensor of PBC box. + """ + if self.pbc_box is None: + return None + return self.identity(self.pbc_box) + + def construct(self) -> Tuple[Tensor, Tensor]: + r""" + Get space information of system. + + Returns: + - coordinate (Tensor), Tensor of shape (B, A, D). Data type is float. + - pbc_box (Tensor), Tensor of shape (B, D). Data type is float. 
+ + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. + """ + coordinate = self.identity(self.coordinate) + pbc_box = None + if self.pbc_box is not None: + pbc_box = self.identity(self.pbc_box) + return coordinate, pbc_box diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/molecule/protein.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/molecule/protein.py new file mode 100644 index 0000000000000000000000000000000000000000..9750385f03d391e02b41ab4e1ec835ecf85cf85d --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/molecule/protein.py @@ -0,0 +1,124 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Protein modeling. +""" + +import numpy as np +from mindspore.common import Tensor +from .molecule import Molecule +from ..residue.amino import AminoAcid +from ..modeling.hadder import ReadPdbByMindsponge as read_pdb +from ...data.template import get_template + + +backbone_atoms = np.array(['N', 'CA', 'C', 'O'], np.str_) +include_backbone_atoms = np.array(['OXT'], np.str_) + + +class Protein(Molecule): + r""" + Protein molecule. + + Args: + pdb (str): Name of the pdb file used to build the protein. Default: None. + sequence (list): Residue sequence used to build the protein when no pdb file is given + (not yet implemented). Default: None. + coordinate (Tensor): Tensor of shape (B, A, D) or (1, A, D). Data type is float. + Position coordinates of atoms. Default: None. + pbc_box (Tensor): Tensor of shape (B, D) or (1, D). Data type is float. + Box of periodic boundary condition. Default: None. + template (Union[dict, str]): Template of residue. + The key of the dict are base, template, the name of molecule and so on. + The value of the dict is file name. + Default: 'protein0.yaml' + ignore_hydrogen (bool, optional): If True, hydrogen atoms in the input are discarded + and rebuilt by AddHydrogen. Default: True. + length_unit (str): Length unit for position coordinates. Default: None. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Symbols: + B: Batchsize, i.e. number of walkers in simulation. + A: Number of atoms. + D: Dimension of the simulation system. Usually is 3. 
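A construction sketch (the input path is hypothetical):

```python
# from mindsponge1.system.molecule import Protein  # path as in this patch

# With ignore_hydrogen=True, hydrogens are stripped and rebuilt through
# AddHydrogen before the structure is parsed into residues.
protein = Protein(pdb='complex.pdb', ignore_hydrogen=True)
coordinate, pbc_box = protein()
```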
+ """ + + def __init__(self, + pdb: str = None, + sequence: list = None, + coordinate: Tensor = None, + pbc_box: Tensor = None, + template: dict = 'protein0.yaml', + ignore_hydrogen: bool = True, + length_unit: str = None, + ): + + super().__init__(length_unit=length_unit) + + if pdb is None: + #TODO + if sequence is None: + raise ValueError('At least 1 of pdb name and residue sequence should be given.') + else: + _, residue_name, _, coordinate, residue_pointer, flatten_atoms, flatten_crds, init_res_names,\ + init_res_ids, \ + _, _, _, _, _ = read_pdb( + pdb, ignore_hydrogen) + + if len(residue_name) > 1: + if residue_name[0] != 'ACE' and residue_name[0] != 'NME': + residue_name[0] = 'N' + residue_name[0] + if residue_name[-1] != 'ACE' and residue_name[-1] != 'NME': + residue_name[-1] = 'C' + residue_name[-1] + + self.init_resname = init_res_names + self.init_resid = init_res_ids + num_residue = len(residue_name) + residue_pointer = np.append(residue_pointer, len(flatten_atoms)) + template = get_template(template) + + self.residue = [] + for i in range(num_residue): + name = residue_name[i] + atom_name = flatten_atoms[residue_pointer[i]: residue_pointer[i + 1]][None, :] + residue = AminoAcid(name=name, template=template, atom_name=atom_name) + self.residue.append(residue) + + coordinate = flatten_crds * self.units.convert_length_from('A') + + self.build_system() + self.build_space(coordinate, pbc_box) + + def get_head_atom(self, residue_index, this_atom_names): + if residue_index == 0: + return None + for index, atom in enumerate(this_atom_names[0]): + if atom == 'N': + return np.array([index], np.int32) + return self + + def get_tail_atom(self, this_atom_names): + for index, atom in enumerate(this_atom_names[0]): + if atom == 'C': + return np.array([index], np.int32) + return self diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/__init__.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2cda13452d0e2afdb1ef1d1117daa01b630f97d4 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Residues +""" + +from .residue import Residue +from .amino import AminoAcid + +__all__ = ['Residue', 'AminoAcid'] diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/amino.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/amino.py new file mode 100644 index 0000000000000000000000000000000000000000..8e3be1920d22850e2ce415d5115358db9257fd00 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/amino.py @@ -0,0 +1,57 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Molecule +""" +from mindspore import ms_class +from .residue import Residue + + +@ms_class +class AminoAcid(Residue): + r""" + Residue of amino acid. + + Args: + name (str): Name of the residue. Default: '' + template (dict or str): Template of Residue. Default: None + atom_name (list): Atom name. Can be ndarray or list of str. Default: None + start_index (int): The start index of the first atom in this residue. Default: 0 + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + + def __init__(self, + name: str = '', + template: dict = None, + atom_name: str = None, + start_index: int = 0, + ): + + super().__init__( + atom_name=atom_name, + start_index=start_index, + name=(name.replace('HIE', 'HIS') if 'HIE' in name else name), + template=template, + ) diff --git a/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/residue.py b/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/residue.py new file mode 100644 index 0000000000000000000000000000000000000000..708ffcaea1b62204a5b57d5a62083337b7dd746c --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/mindsponge1/system/residue/residue.py @@ -0,0 +1,582 @@ +# Copyright 2021-2022 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
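For illustration, the `AminoAcid` wrapper above only normalizes Amber's NE2-protonated histidine name before delegating to `Residue` (both arguments below are placeholders for a loaded template dict and an atom-name array):

```python
# 'HIE' (and 'NHIE'/'CHIE') name the NE2-protonated histidine tautomer;
# AminoAcid rewrites 'HIE' to 'HIS' so the standard template is found.
res = AminoAcid(name='HIE', template=template, atom_name=atom_name)  # stored as 'HIS'
```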
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Residue
+"""
+
+from operator import itemgetter
+from typing import Union
+import numpy as np
+from numpy import ndarray
+import mindspore as ms
+from mindspore import numpy as msnp
+from mindspore import ms_class
+from mindspore.ops import functional as F
+from mindspore.common import Tensor
+
+from ...function.functions import get_integer
+from ...data.elements import elements, element_set, element_dict, atomic_mass
+from ...data.template import get_template, get_template_index
+
+
+@ms_class
+class Residue:
+    r"""
+    Class for a residue in a molecule.
+
+    Args:
+        atom_name (list):       Atom names. Can be ndarray or list of str. Default: None.
+        atom_type (list):       Atom types. Can be ndarray or list of str. Default: None.
+        atom_mass (Tensor):     Tensor of shape (B, A). Data type is float.
+                                Atom mass. Default: None.
+        atom_charge (Tensor):   Tensor of shape (B, A). Data type is float.
+                                Atom charge. Default: None.
+        atomic_number (Tensor): Tensor of shape (B, A). Data type is int.
+                                Atomic number. Default: None.
+        bond (Tensor):          Tensor of shape (B, b, 2) or (1, b, 2). Data type is int.
+                                Bond index. Default: None.
+        head_atom (int):        Index of the head atom to connect with the previous residue.
+                                Default: None.
+        tail_atom (int):        Index of the tail atom to connect with the next residue.
+                                Default: None.
+        start_index (int):      The start index of the first atom in this residue.
+        name (str):             Name of the residue.
+                                Examples: 'SOL' and 'CL', indicating a water molecule and
+                                a chloride ion respectively.
+                                A residue that is not otherwise defined is usually named 'MOL'.
+                                Default: 'MOL'.
+        template (Union[dict, str]): Template of the residue.
+                                The keys of the dict are 'base', 'template', the name of the
+                                molecule and so on; the values are file names.
+                                Default: None.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Symbols:
+        B:  Batch size, i.e. the number of walkers in the simulation.
+        A:  Number of atoms.
+        b:  Number of bonds.
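+
+    Examples:
+        A minimal usage sketch (illustrative only; it assumes a built-in water
+        template such as 'water.tip3p.yaml' that defines a residue named 'WAT'
+        with atoms 'O', 'H1' and 'H2'):
+
+        >>> from mindsponge1.system.residue import Residue
+        >>> res = Residue(atom_name=['O', 'H1', 'H2'], name='WAT', template='water.tip3p.yaml')
+        >>> res.num_atoms
+        3
+        >>> res.atom_name.shape
+        (1, 3)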
+ """ + + def __init__(self, + atom_name: list = None, + atom_type: list = None, + atom_mass: Tensor = None, + atom_charge: Tensor = None, + atomic_number: Tensor = None, + bond: Tensor = None, + head_atom: int = None, + tail_atom: int = None, + start_index: int = 0, + name: str = 'MOL', + template: Union[dict, str] = None, + ): + + self._name = name + + self.atom_name = None + if atom_name is not None: + self.atom_name = np.array(atom_name, np.str_) + if self.atom_name.ndim == 1: + self.atom_name = np.expand_dims(self.atom_name, 0) + if self.atom_name.ndim != 2: + raise ValueError('The rank of "atom_name" must be 1 or 2!') + + if template is not None: + template = get_template(template) + if self._name is None: + if len(template) == 1: + self._name = list(template.keys())[0] + template = template.get(self._name) + else: + raise ValueError('The name cannot be None when the number of ' + 'keys in template is larger than 1!') + elif self._name not in template.keys(): + raise ValueError('Cannot found the key "' + str(self._name) + ' in template."') + + template = template.get(self._name) + + atom_mass = np.array(template.get('atom_mass'), np.float32) + atomic_number = np.array(template.get('atom_mass'), np.int32) + + atom_type = template.get('atom_type') + if atom_type is not None: + atom_type = np.array(atom_type, np.str_) + + atom_charge = template.get('atom_charge') + if atom_charge is not None: + atom_charge = np.array(atom_charge, np.float32) + + bond = template.get('bond') + if bond is not None: + bond = np.array(bond, np.int32) + + head_atom = template.get('head_atom') + tail_atom = template.get('tail_atom') + + if self.atom_name is None: + self.atom_name = np.array(template.get('atom_name'), np.str_).reshape(1, -1) + else: + atom_index = get_template_index(template, self.atom_name) + atom_mass = atom_mass[atom_index] + atomic_number = atomic_number[atom_index] + + if atom_type is not None: + atom_type = atom_type[atom_index] + + if atom_charge is not None: + atom_charge = atom_charge[atom_index] + + if bond is not None: + bond = self._get_bond(template, atom_index) + + serial_list: list = atom_index.reshape(-1).tolist() + + if head_atom is not None: + head_atom = serial_list.index(head_atom) + + if tail_atom is not None: + tail_atom = serial_list.index(tail_atom) + + if self.atom_name is None and atomic_number is None: + raise ValueError('atom_name and atomic_number cannot both be None') + + if atomic_number is not None: + self.atomic_number = Tensor(atomic_number, ms.int32) + if self.atomic_number.ndim == 1: + self.atomic_number = F.expand_dims(self.atomic_number, 0) + if self.atomic_number.ndim != 2: + raise ValueError('The rank of "atomic_number" must be 1 or 2!') + + if self.atom_name is None: + self.atom_name = np.array(elements[self.atomic_number.asnumpy()], np.str_) + + if atomic_number is None: + atom_name_list = self.atom_name.reshape(-1).tolist() + if set(atom_name_list) - element_set: + self.atomic_number = msnp.ones(self.atom_name.shape, ms.int32) + else: + atomic_number = itemgetter(*atom_name_list)(element_dict) + self.atomic_number = Tensor(atomic_number, ms.int32).reshape(self.atom_name.shape) + + if self.atomic_number.shape != self.atom_name.shape: + if self.atomic_number.shape[-1] == self.atom_name.shape[-1]: + if self.atomic_number.shape[0] == 1: + self.atomic_number = msnp.broadcast_to(self.atomic_number, self.atom_name.shape) + elif self.atom_name.shape[0] == 1: + self.atom_name = msnp.broadcast_to(self.atom_name, self.atomic_number.shape) + + raise 
+
+        if atom_type is None:
+            self.atom_type = self.atom_name.copy()
+        else:
+            self.atom_type = np.array(atom_type)
+            if self.atom_type.ndim == 1:
+                self.atom_type = np.expand_dims(self.atom_type, 0)
+            if self.atom_type.shape != self.atom_name.shape:
+                raise ValueError('The shape of "atom_type" ' + str(self.atom_type.shape) +
+                                 ' must be equal to the shape of "atom_name" ' + str(self.atom_name.shape) + '!')
+
+        self.num_atoms = self.atom_name.shape[-1]
+        self.multi_system = self.atom_name.shape[0]
+
+        self.start_index = get_integer(start_index)
+        # (A'')
+        self._index = msnp.arange(self.num_atoms)
+        self.system_index = self._index + start_index
+
+        # (1,A') or (B,A')
+        if atom_mass is None:
+            if atomic_number is None:
+                self.atom_mass = msnp.ones(self.atom_name.shape, ms.float32)
+            else:
+                self.atom_mass = Tensor(atomic_mass[self.atomic_number.asnumpy()], ms.float32)
+        else:
+            self.atom_mass = Tensor(atom_mass, ms.float32)
+            if self.atom_mass.ndim == 1:
+                self.atom_mass = F.expand_dims(self.atom_mass, 0)
+            if self.atom_mass.ndim != 2:
+                raise ValueError('The rank of "atom_mass" must be 1 or 2!')
+            if self.atom_mass.shape[-1] != self.num_atoms:
+                raise ValueError('The last dimension of atom_mass (' + str(self.atom_mass.shape[-1]) +
+                                 ') must be equal to the number of atoms (' + str(self.num_atoms) + ')!')
+            if self.atom_mass.shape[0] > 1 and self.atom_mass.shape[0] != self.multi_system:
+                raise ValueError('The first dimension of atom_mass (' + str(self.atom_mass.shape[0]) +
+                                 ') does not match the number of systems multi_system (' +
+                                 str(self.multi_system) + ')!')
+
+        # (B,A')
+        self.atom_mask = F.logical_and(self.atomic_number > 0, self.atom_mass > 0)
+        self.inv_mass = msnp.where(self.atom_mask, msnp.reciprocal(self.atom_mass), 0)
+        # (B,1)
+        self.natom_tensor = msnp.sum(F.cast(self.atom_mask, ms.float32), -1, keepdims=True)
+        self.total_mass = msnp.sum(self.atom_mass, -1, keepdims=True)
+
+        # (B,A')
+        self.atom_charge = atom_charge
+        if atom_charge is not None:
+            self.atom_charge = Tensor(atom_charge, ms.float32)
+            if self.atom_charge.ndim == 1:
+                self.atom_charge = F.expand_dims(self.atom_charge, 0)
+            if self.atom_charge.ndim != 2:
+                raise ValueError('The rank of "atom_charge" must be 1 or 2!')
+            if self.atom_charge.shape[-1] != self.num_atoms:
+                raise ValueError('The last dimension of atom_charge (' + str(self.atom_charge.shape[-1]) +
+                                 ') must be equal to the number of atoms (' + str(self.num_atoms) + ')!')
+            if self.atom_charge.shape[0] != self.multi_system and self.atom_charge.shape[0] != 1:
+                raise ValueError('The first dimension of atom_charge (' + str(self.atom_charge.shape[0]) +
+                                 ') must be equal to 1 or the number of systems multi_system (' +
+                                 str(self.multi_system) + ')!')
+
+        # (B,C,2)
+        self.bond = bond
+        self.bond_mask = None
+        if bond is not None:
+            self.bond = Tensor(bond, ms.int32)
+            if self.bond.shape[-1] != 2:
+                raise ValueError('The last dimension of bond must be 2!')
+            if self.bond.ndim == 2:
+                self.bond = F.expand_dims(self.bond, 0)
+            self.bond_mask = self.bond < self.num_atoms
+
+        # (B,1)
+        self.head_atom = head_atom
+        if head_atom is not None:
+            self.head_atom = Tensor([head_atom,], ms.int32).reshape(-1, 1)
+            if self.head_atom.shape[0] != self.multi_system and self.head_atom.shape[0] != 1:
+                raise ValueError('The first dimension of head_atom (' + str(self.head_atom.shape[0]) +
+                                 ') does not match the number of systems multi_system (' +
+                                 str(self.multi_system) + ')!')
+            if (self.head_atom >= self.num_atoms).any():
+                raise ValueError('The value of head_atom exceeds the number of atoms.')
+
+        # (B,1)
+        self.tail_atom = tail_atom
+        if tail_atom is not None:
+            self.tail_atom = Tensor([tail_atom,], ms.int32).reshape(-1, 1)
+            if self.tail_atom.shape[0] != self.multi_system and self.tail_atom.shape[0] != 1:
+                raise ValueError('The first dimension of tail_atom (' + str(self.tail_atom.shape[0]) +
+                                 ') does not match the number of systems multi_system (' +
+                                 str(self.multi_system) + ')!')
+            if (self.tail_atom >= self.num_atoms).any():
+                raise ValueError('The value of tail_atom exceeds the number of atoms.')
+
+    @property
+    def name(self) -> str:
+        return str(self._name)
+
+    @classmethod
+    def _get_atom_mass(cls, template: dict, atom_index: ndarray = None) -> ndarray:
+        """get atom mass from template and atom index"""
+        atom_mass = np.array(template.get('atom_mass'), np.float32)
+        if atom_index is not None:
+            atom_mass = atom_mass[atom_index]
+        return atom_mass
+
+    @classmethod
+    def _get_atomic_number(cls, template: dict, atom_index: ndarray = None) -> ndarray:
+        """get atomic number from template and atom index"""
+        atomic_number = np.array(template.get('atomic_number'), np.int32)
+        if atom_index is not None:
+            atomic_number = atomic_number[atom_index]
+        return atomic_number
+
+    @classmethod
+    def _get_atom_type(cls, template: dict, atom_index: ndarray = None) -> ndarray:
+        """get atom type from template and atom index"""
+        atom_type = np.array(template.get('atom_type'), np.str_)
+        if atom_index is not None:
+            atom_type = atom_type[atom_index]
+        return atom_type
+
+    @classmethod
+    def _get_atom_charge(cls, template: dict, atom_index: ndarray = None) -> ndarray:
+        """get atom charge from template and atom index"""
+        atom_charge = np.array(template['atom_charge'], np.float32)
+        if atom_index is not None:
+            atom_charge = atom_charge[atom_index]
+        return atom_charge
+
+    @classmethod
+    def _get_bond(cls, template: dict, atom_index: ndarray = None) -> ndarray:
+        """get bond from template and atom index"""
+        bond = np.array(template.get('bond'))
+        if atom_index is not None:
+            bond_list = bond.reshape(-1).tolist()
+            if atom_index.ndim == 2 and atom_index.shape[0] > 1:
+                bond_ = []
+                for serial in atom_index:
+                    serial: list = serial.tolist()
+                    b = np.array([serial.index(idx)
+                                  for idx in bond_list]).reshape(bond.shape)
+                    bond_.append(b)
+                bond = np.stack(bond_, axis=0)
+            else:
+                serial: list = atom_index.reshape(-1).tolist()
+                bond = np.array([serial.index(idx) for idx in bond_list]).reshape(bond.shape)
+        return bond
+
+    def build_atom_mass(self, template: dict):
+        """
+        Set the atom masses from the template according to the atom names.
+
+        Args:
+            template (Union[dict, str]): Template of the residue.
+                The keys of the dict are 'base', 'template', the name of the
+                molecule and so on; the values are file names.
+                Default: None.
+        """
+        atom_index = get_template_index(template, self.atom_name)
+        self.atom_mass = Tensor(self._get_atom_mass(template, atom_index), ms.float32)
+        return self
+
+    def build_atomic_number(self, template: dict):
+        """
+        Set the atomic numbers from the template according to the atom names.
+
+        Args:
+            template (Union[dict, str]): Template of the residue.
+                The keys of the dict are 'base', 'template', the name of the
+                molecule and so on; the values are file names.
+                Default: None.
+ """ + atom_index = get_template_index(template, self.atom_name) + self.atomic_number = Tensor(self._get_atomic_number(template, atom_index), ms.int32) + return self + + def build_atom_type(self, template: dict): + """ + This function is built to attach the type of atom to the index of atom. + + Args: + template (Union[dict, str]): Template of residue. + The key of the dict are base, template, the name of molecule and so on. + The value of the dict is file name. + Default: None. + """ + atom_index = get_template_index(template, self.atom_name) + self.atom_type = self._get_atom_type(template, atom_index) + return self + + def build_atom_charge(self, template: dict): + """ + This function is built to attach the chargre of atom to the index of atom. + + Args: + template (Union[dict, str]): Template of residue. + The key of the dict are base, template, the name of molecule and so on. + The value of the dict is file name. + Default: None. + """ + atom_index = get_template_index(template, self.atom_name) + self.atom_charge = Tensor(self._get_atom_charge(template, atom_index), ms.float32) + return self + + def build_bond(self, template: dict): + """ + This function is built to attach the bonds of atom to the index of atom. + + Args: + template (Union[dict, str]): Template of residue. + The key of the dict are base, template, the name of molecule and so on. + The value of the dict is file name. + Default: None. + """ + atom_index = get_template_index(template, self.atom_name) + self.bond = Tensor(self._get_bond(template, atom_index), ms.int32) + return self + + def add_atom(self, + atom_name: str = None, + atom_type: str = None, + atom_mass: float = None, + atom_charge: float = None, + atomic_number: str = None, + ): + """ + Set atom. + + Args: + atom_name (Union[numpy.ndarray, list(str)]): Atom name. Can be ndarray or list of str. Default: None. + atom_type (Union[numpy.ndarray, list(str)]): Atom type. Can be ndarray or list of str. Default: None. + atom_mass (Tensor): Tensor of shape (B, A). Data type is float. + Atom mass. Default: None. + atom_charge (Tensor): Tensor of shape (B, A). Data type is float. + Atom charge. Default: None. + atomic_number (Tensor): Tensor of shape (B, A). Data type is float. + Atomic number. Default: None. 
+ """ + + if atom_name is None and atomic_number is None: + raise ValueError('atom_name and atomic_number cannot both be None') + + shape = (self.multi_system, 1) + + if atom_name is not None: + atom_name = np.array(atom_name, np.str_) + atom_name = np.broadcast_to(atom_name, shape) + + if atomic_number is not None: + atomic_number = Tensor(atomic_number, ms.int32) + atomic_number = msnp.broadcast_to(atomic_number, shape) + + if atom_name is None: + atom_name = elements[atomic_number.asnumpy()] + + if atom_mass is None: + if atomic_number is None: + atom_mass = msnp.ones(atom_name.shape, dtype=np.float32) + else: + atom_mass = Tensor( + atomic_mass[atomic_number.asnumpy()], ms.float32) + else: + atom_mass = Tensor(atom_mass, ms.float32) + atom_mass = np.broadcast_to(atom_mass, shape) + + if atomic_number is None: + atom_name_list = atom_name.reshape(-1).tolist() + if set(atom_name_list) - element_set: + atomic_number = msnp.ones(atom_name.shape, ms.int32) + else: + atomic_number = itemgetter(*atom_name_list)(element_dict) + atomic_number = Tensor( + atomic_number, ms.int32).reshape(atom_name.shape) + + if atomic_number.shape != atom_name.shape: + if atomic_number.shape[-1] == atom_name.shape[-1]: + if atomic_number.shape[0] == 1: + atomic_number = msnp.broadcast_to( + atomic_number, atom_name.shape) + elif atom_name.shape[0] == 1: + atom_name = msnp.broadcast_to( + atom_name, atomic_number.shape) + + raise ValueError('The shape of "atomic_number" '+str(atomic_number) + + ' does not match the shape of "atom_name" '+str(atom_name)+'!') + + atom_mask = F.logical_and(atomic_number > 0, atom_mass > 0) + inv_mass = msnp.where(atom_mask, msnp.reciprocal(atom_mass), 0) + + if atom_type is None: + atom_type = atom_name.copy() + else: + atom_type = np.array(atom_type) + atom_type = np.broadcast_to(atom_type, shape) + + if atom_charge is not None: + atom_charge = Tensor(atom_charge, ms.float32) + atom_charge = np.broadcast_to(atom_charge, shape) + + self.atom_name = np.concatenate((self.atom_name, atom_name), axis=-1) + self.atom_type = np.concatenate((self.atom_type, atom_type), axis=-1) + self.atom_mass = F.concat((self.atom_mass, atom_mass), -1) + self.atom_mask = F.concat((self.atom_mask, atom_mask), -1) + self.atomic_number = F.concat((self.atomic_number, atomic_number), -1) + self.inv_mass = F.concat((self.inv_mass, inv_mass), -1) + if self.atom_charge is None and atom_charge is not None: + self.atom_charge = msnp.zeros( + (self.multi_system, self.num_atoms), ms.float32) + if self.atom_charge is not None and atom_charge is None: + atom_charge = msnp.zeros((self.multi_system, 1), ms.float32) + if atom_charge is not None or self.atom_charge is not None: + self.atom_charge = F.concat((self.atom_charge, atom_charge), -1) + + self.num_atoms = self.atom_name.shape[-1] + self._index = msnp.arange(self.num_atoms) + self.system_index = self._index + self.start_index + self.natom_tensor = msnp.sum( + F.cast(self.atom_mask, ms.int32), -1, keepdims=True) + self.total_mass = msnp.sum(self.atom_mass, -1, keepdims=True) + + return self + + def broadcast_multiplicity(self, multi_system: int): + """ + Broadcast the information to the number of multiple system. + + Args: + multi_system (int): Amount of multiple system. 
+ """ + if multi_system <= 0: + raise ValueError('multi_system must be larger than 0!') + if self.multi_system > 1: + raise ValueError('The current the number of system multi_system ('+str(self.multi_system) + + ') is larger than 1 and cannot be broadcast!') + + self.multi_system = multi_system + self.atom_name = msnp.broadcast_to(self.atom_name, (self.multi_system, -1)) + self.atom_type = msnp.broadcast_to(self.atom_mass, (self.multi_system, -1)) + self.atomic_number = msnp.broadcast_to(self.atomic_number, (self.multi_system, -1)) + self.atom_mass = msnp.broadcast_to(self.atom_mass, (self.multi_system, -1)) + self.atom_mask = msnp.broadcast_to(self.atom_mask, (self.multi_system, -1)) + self.inv_mass = msnp.broadcast_to(self.inv_mass, (self.multi_system, -1)) + self.total_mass = msnp.broadcast_to(self.total_mass, (self.multi_system, -1)) + self.natom_tensor = msnp.broadcast_to(self.natom_tensor, (self.multi_system, -1)) + if self.atom_charge is not None: + self.atom_charge = msnp.broadcast_to(self.atom_charge, (self.multi_system, -1)) + if self.bond is not None: + bond_shape = (self.multi_system,) + self.bond.shape[1:] + self.bond = msnp.broadcast_to(self.bond, bond_shape) + self.bond_mask = msnp.broadcast_to(self.bond_mask, bond_shape) + if self.head_atom is not None: + self.head_atom = msnp.broadcast_to( + self.head_atom, (self.multi_system, -1)) + if self.tail_atom is not None: + self.tail_atom = msnp.broadcast_to( + self.tail_atom, (self.multi_system, -1)) + + return self + + def set_name(self, name: str): + """ + Set residue name of this residue. + + Args: + name (str): Name of the residue. + Examples: 'SOL', 'CL'. Indicating water molecule and Na+ ion respectively. + The residue that is not defined usually called 'MOL'. + Default: 'MOL'. + """ + self._name = name + return self + + def set_start_index(self, start_index: int): + """ + Set the start index of the first atom in this residue. + + Args: + start_index (int): The start index of the first atom in this residue. + """ + if start_index < 0: + raise ValueError('The start_index cannot be smaller than 0!') + self.start_index = get_integer(start_index) + index_shift = self.start_index - self.system_index[0] + self.system_index += index_shift + return self diff --git a/MindSPONGE/applications/research/Grasp/model/__init__.py b/MindSPONGE/applications/research/Grasp/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa66240dde1a1e344968d6ea3923271e3644a6a --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/model/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +'''init''' +from .fold import MegaFold, compute_confidence, compute_ranking_score +from .assessment import MegaAssessment +from .evogen import MegaEvogen diff --git a/MindSPONGE/applications/research/Grasp/model/assessment.py b/MindSPONGE/applications/research/Grasp/model/assessment.py new file mode 100644 index 0000000000000000000000000000000000000000..7fbfa4ec210cad3418096c45fb625c40b80b4312 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/model/assessment.py @@ -0,0 +1,345 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""model""" +from collections import defaultdict +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore.ops import operations as P +from mindspore import Tensor, Parameter, load_checkpoint +from mindsponge.common.utils import dgram_from_positions, pseudo_beta_fn, atom37_to_torsion_angles +from mindsponge.cell.initializer import lecun_init +from module.template_embedding import TemplateEmbedding +from module.evoformer import Evoformer +from module.structure import StructureModule +from module.head import DistogramHead, PredictedLDDTHead, EstogramHead +from model.fold import caculate_constant_array, MegaFold + + +def load_weights(model_path, config): + """ + Load checkpoint as parameter dict, support both npz file and mindspore checkpoint file. + """ + ms_ckpt = load_checkpoint(model_path) + weights = defaultdict(str) + for msname in ms_ckpt: + if "msa_stack" in msname and "extra" not in msname: + for i in range(config.evoformer.msa_stack_num): + temp_name = msname.split(".") + temp_name.insert(1, str(i)) + infer_name = "fold." + ".".join(temp_name) + weights[infer_name] = ms_ckpt[msname].data.asnumpy()[i] + + for i in range(config.evoformer.msa_stack_num_assessment): + temp_name = msname.split(".") + temp_name.insert(1, str(i)) + infer_name = "assessment." + ".".join(temp_name) + weights[infer_name] = ms_ckpt[msname].data.asnumpy()[i] + else: + infer_name = "fold." + msname + weights[infer_name] = ms_ckpt[msname].data.asnumpy() + infer_name = "assessment." 
+ msname + weights[infer_name] = ms_ckpt[msname].data.asnumpy() + + parameter_dict = defaultdict(str) + for name in weights: + parameter_dict[name] = Parameter(Tensor(weights[name]), name=name) + return parameter_dict + + +class CombineModel(nn.Cell): + """Combine MegaFold and MegaAssessment""" + + def __init__(self, config, mixed_precision): + super(CombineModel, self).__init__() + self.fold = MegaFold(config, mixed_precision=mixed_precision) + config.max_extra_msa = 2 + config.max_msa_clusters = 2 + config.slice.extra_msa_stack.msa_row_attention_with_pair_bias = 0 + self.assessment = MegaAssessment(config, mixed_precision=mixed_precision) + + def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion, + extra_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, residue_index, + prev_pos, prev_msa_first_row, prev_pair, final_atom_positions_recycle=None, + final_atom_mask_recycle=None, run_pretrain=True): + """construct""" + if run_pretrain: + out = self.fold(target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, + extra_has_deletion, extra_deletion_value, extra_msa_mask, residx_atom37_to_atom14, + atom37_atom_exists, residue_index, prev_pos, prev_msa_first_row, prev_pair) + else: + out = self.assessment(target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, + extra_has_deletion, extra_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, residue_index, + prev_pos, prev_msa_first_row, prev_pair, final_atom_positions_recycle, + final_atom_mask_recycle) + return out + + +class MegaAssessment(nn.Cell): + """MegaAssessment""" + + def __init__(self, config, mixed_precision): + super(MegaAssessment, self).__init__() + + self.cfg = config + + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.is_training = self.cfg.is_training + self.recycle_pos = self.cfg.recycle_pos + self.recycle_features = self.cfg.recycle_features + self.max_relative_feature = self.cfg.max_relative_feature + self.num_bins = self.cfg.prev_pos.num_bins + self.min_bin = self.cfg.prev_pos.min_bin + self.max_bin = self.cfg.prev_pos.max_bin + self.template_enabled = self.cfg.template.enabled + self.template_embed_torsion_angles = self.cfg.template.embed_torsion_angles + self.extra_msa_stack_num = self.cfg.evoformer.extra_msa_stack_num_assessment + self.msa_stack_num = self.cfg.evoformer.msa_stack_num_assessment + self.chi_atom_indices, self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, \ + self.indices0, self.indices1 = caculate_constant_array(self.cfg.seq_length) + + self.preprocess_1d = nn.Dense(self.cfg.common.target_feat_dim, self.cfg.msa_channel, + weight_init=lecun_init(self.cfg.common.target_feat_dim)) + self.preprocess_msa = nn.Dense(self.cfg.common.msa_feat_dim, self.cfg.msa_channel, + weight_init=lecun_init(self.cfg.common.msa_feat_dim)) + self.left_single = nn.Dense(self.cfg.common.target_feat_dim, self.cfg.pair_channel, + weight_init=lecun_init(self.cfg.common.target_feat_dim)) + self.right_single = nn.Dense(self.cfg.common.target_feat_dim, 
self.cfg.pair_channel, + weight_init=lecun_init(self.cfg.common.target_feat_dim)) + self.prev_pos_linear = nn.Dense(self.cfg.common.dgram_dim, self.cfg.pair_channel, + weight_init=lecun_init(self.cfg.common.dgram_dim)) + self.pair_activations = nn.Dense(self.cfg.common.pair_in_dim, self.cfg.pair_channel, + weight_init=lecun_init(self.cfg.common.pair_in_dim)) + self.extra_msa_one_hot = nn.OneHot(depth=23, axis=-1) + self.template_aatype_one_hot = nn.OneHot(depth=22, axis=-1) + self.prev_msa_first_row_norm = nn.LayerNorm([256,], epsilon=1e-5) + self.prev_pair_norm = nn.LayerNorm([128,], epsilon=1e-5) + self.one_hot = nn.OneHot(depth=self.cfg.max_relative_feature * 2 + 1, axis=-1) + self.extra_msa_activations = nn.Dense(25, self.cfg.extra_msa_channel, weight_init=lecun_init(25)) + self.template_embedding = TemplateEmbedding(self.cfg, mixed_precision) + + self.matmul_trans_b = P.MatMul(transpose_b=True) + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.template_single_embedding = nn.Dense(57, self.cfg.msa_channel, + weight_init= + lecun_init(57, initializer_name='relu')) + self.template_projection = nn.Dense(self.cfg.msa_channel, self.cfg.msa_channel, + weight_init=lecun_init(self.cfg.msa_channel, + initializer_name='relu')) + self.relu = nn.ReLU() + self.single_activations = nn.Dense(self.cfg.msa_channel, self.cfg.seq_channel, + weight_init=lecun_init(self.cfg.msa_channel)) + extra_msa_stack = nn.CellList() + for _ in range(self.extra_msa_stack_num): + extra_msa_block = Evoformer(self.cfg, + msa_act_dim=64, + pair_act_dim=128, + is_extra_msa=True, + batch_size=None) + extra_msa_stack.append(extra_msa_block) + self.extra_msa_stack = extra_msa_stack + if self.is_training: + msa_stack = nn.CellList() + for _ in range(self.msa_stack_num): + msa_block = Evoformer(self.cfg, + msa_act_dim=256, + pair_act_dim=128, + is_extra_msa=False, + batch_size=None) + msa_stack.append(msa_block) + self.msa_stack = msa_stack + else: + self.msa_stack = Evoformer(self.cfg, + msa_act_dim=256, + pair_act_dim=128, + is_extra_msa=False, + batch_size=self.msa_stack_num) + self.idx_evoformer_block = Parameter(Tensor(0, mstype.int32), requires_grad=False) + self.evoformer_num_block_eval = Tensor(self.msa_stack_num, mstype.int32) + + self.structure_module = StructureModule(self.cfg, + self.cfg.seq_channel, + self.cfg.pair_channel) + + self.module_lddt = PredictedLDDTHead(self.cfg.heads.predicted_lddt, + self.cfg.seq_channel) + self.module_distogram = DistogramHead(self.cfg.heads.distogram, + self.cfg.pair_channel) + if self.is_training: + self.module_lddt_decoy = PredictedLDDTHead(self.cfg.heads.predicted_lddt, + self.cfg.seq_channel) + self.module_estogram = EstogramHead(first_break=self.cfg.heads.distogram.first_break, + last_break=self.cfg.heads.distogram.last_break, + num_bins=self.cfg.heads.distogram.num_bins) + + self.norm_0 = LayerNormDense(self.cfg.msa_channel, self.cfg.seq_channel) + self.norm_1 = LayerNormDense(self.cfg.msa_channel, self.cfg.seq_channel) + self.norm_2 = LayerNormDense(self.cfg.msa_channel, self.cfg.seq_channel) + self.extra_msa_length = self.cfg.max_extra_msa + self.msa_cluster_length = self.cfg.max_msa_clusters + + def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion, + extra_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, residue_index, + prev_pos, 
prev_msa_first_row, prev_pair, final_atom_positions_recycle, final_atom_mask_recycle): + """construct""" + decoy_pseudo_beta, decoy_pseudo_beta_mask = pseudo_beta_fn(aatype, final_atom_positions_recycle, + final_atom_mask_recycle) + extra_msa = mnp.zeros_like(extra_msa[:self.extra_msa_length]) + extra_has_deletion = mnp.zeros_like(extra_has_deletion[:self.extra_msa_length]) + extra_deletion_value = mnp.zeros_like(extra_deletion_value[:self.extra_msa_length]) + extra_msa_mask = mnp.zeros_like(extra_msa_mask[:self.extra_msa_length]) + msa_feat = mnp.concatenate((msa_feat[0:1], mnp.zeros_like(msa_feat[1:self.msa_cluster_length])), axis=0) + msa_mask = mnp.concatenate((msa_mask[0:1], mnp.zeros_like(msa_mask[1:self.msa_cluster_length])), axis=0) + template_aatype = mnp.concatenate((aatype[None], mnp.zeros_like(template_aatype[1:])), axis=0) + template_mask = mnp.concatenate((mnp.ones_like(template_mask[0:1]), mnp.zeros_like(template_mask[1:])), axis=0) + template_all_atom_masks[0] = final_atom_mask_recycle + template_all_atom_positions[0] = final_atom_positions_recycle + template_mask[0] = mnp.ones_like(template_mask[0]) + template_pseudo_beta_mask[0] = decoy_pseudo_beta_mask + template_pseudo_beta[0] = decoy_pseudo_beta + + preprocess_1d = self.preprocess_1d(target_feat) + preprocess_msa = self.preprocess_msa(msa_feat) + msa_activations = mnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa + left_single = self.left_single(target_feat) + right_single = self.right_single(target_feat) + pair_activations = P.ExpandDims()(left_single, 1) + P.ExpandDims()(right_single, 0) + mask_2d = P.ExpandDims()(seq_mask, 1) * P.ExpandDims()(seq_mask, 0) + if self.recycle_pos: + prev_pseudo_beta = pseudo_beta_fn(aatype, prev_pos, None) + dgram = dgram_from_positions(prev_pseudo_beta, self.num_bins, self.min_bin, self.max_bin, self._type) + pair_activations += self.prev_pos_linear(dgram) + + if self.recycle_features: + prev_msa_first_row = self.prev_msa_first_row_norm(prev_msa_first_row) + msa_activations = mnp.concatenate( + (mnp.expand_dims(prev_msa_first_row + msa_activations[0, ...], 0), msa_activations[1:, ...]), 0) + pair_activations += self.prev_pair_norm(prev_pair) + + if self.max_relative_feature: + offset = P.ExpandDims()(residue_index, 1) - P.ExpandDims()(residue_index, 0) + rel_pos = self.one_hot(mnp.clip(offset + self.max_relative_feature, 0, 2 * self.max_relative_feature)) + pair_activations += self.pair_activations(rel_pos) + + template_pair_representation = 0 + if self.template_enabled: + template_pair_representation = self.template_embedding(pair_activations, template_aatype, + template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, + template_pseudo_beta, mask_2d) + pair_activations += template_pair_representation + msa_1hot = self.extra_msa_one_hot(extra_msa) + extra_msa_feat = mnp.concatenate((msa_1hot, extra_has_deletion[..., None], extra_deletion_value[..., None]), + axis=-1) + extra_msa_activations = self.extra_msa_activations(extra_msa_feat) + extra_msa_mask_tmp = P.Transpose()(P.ExpandDims()(extra_msa_mask, -1), (2, 1, 0)) + extra_msa_norm = P.Transpose()(self.batch_matmul_trans_b(extra_msa_mask_tmp, extra_msa_mask_tmp), (1, 2, 0)) + for i in range(self.extra_msa_stack_num): + extra_msa_activations, pair_activations = \ + self.extra_msa_stack[i](extra_msa_activations, pair_activations, extra_msa_mask, extra_msa_norm, + mask_2d) + template_activations = None + if self.template_enabled and self.template_embed_torsion_angles: + num_templ, num_res 
= template_aatype.shape + aatype_one_hot = self.template_aatype_one_hot(template_aatype) + torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask = atom37_to_torsion_angles( + template_aatype, template_all_atom_positions, template_all_atom_masks, self.chi_atom_indices, + self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, self.indices0, self.indices1) + template_features = mnp.concatenate([aatype_one_hot, + mnp.reshape(torsion_angles_sin_cos, [num_templ, num_res, 14]), + mnp.reshape(alt_torsion_angles_sin_cos, [num_templ, num_res, 14]), + torsion_angles_mask], axis=-1) + template_activations = self.template_single_embedding(template_features) + template_activations = self.relu(template_activations) + template_activations = self.template_projection(template_activations) + msa_activations = mnp.concatenate([msa_activations, template_activations], axis=0) + torsion_angle_mask = torsion_angles_mask[:, :, 2] + msa_mask = mnp.concatenate([msa_mask, torsion_angle_mask], axis=0) + + msa_mask_tmp = P.Transpose()(P.ExpandDims()(msa_mask, -1), (2, 1, 0)) + msa_mask_norm = P.Transpose()(self.batch_matmul_trans_b(msa_mask_tmp, msa_mask_tmp), (1, 2, 0)) + + msa_decoy = [] + msa_decoy += [self.norm_0(template_activations[0]),] + + if self.is_training: + for i in range(self.msa_stack_num): + msa_activations, pair_activations = self.msa_stack[i](msa_activations, pair_activations, msa_mask, + msa_mask_norm, mask_2d) + else: + self.idx_evoformer_block = self.idx_evoformer_block * 0 + while self.idx_evoformer_block < self.evoformer_num_block_eval: + msa_activations, pair_activations = self.msa_stack(msa_activations, + pair_activations, + msa_mask, + msa_mask_norm, + mask_2d, + self.idx_evoformer_block) + self.idx_evoformer_block += 1 + + msa_decoy += [self.norm_1(msa_activations[0]),] + msa_decoy += [self.norm_2(msa_activations[-4]),] + + single_activations = self.single_activations(msa_activations[0]) + + final_atom_positions, _, rp_structure_module, atom14_pred_positions, final_affines, \ + angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, sidechain_atom_pos, structure_traj = \ + self.structure_module(single_activations, + pair_activations, + seq_mask, + aatype, + residx_atom37_to_atom14, + atom37_atom_exists) + predicted_lddt_logits = self.module_lddt(rp_structure_module) + dist_logits, bin_edges = self.module_distogram(pair_activations) + plddt_dist, pred_mask2d, _ = self.module_estogram(dist_logits, decoy_pseudo_beta, decoy_pseudo_beta_mask) + if self.is_training: + msa_decoy = mnp.concatenate(msa_decoy, axis=-1) + decoy_logits = self.module_lddt_decoy(msa_decoy) + out = dist_logits, bin_edges, atom14_pred_positions, final_affines, angles_sin_cos_new,\ + predicted_lddt_logits, structure_traj, sidechain_frames, sidechain_atom_pos,\ + um_angles_sin_cos_new, final_atom_positions, decoy_pseudo_beta, decoy_pseudo_beta_mask, \ + decoy_logits, plddt_dist, pred_mask2d + return out + return plddt_dist + + +class LayerNormDense(nn.Cell): + """layernorm and dense layer""" + def __init__(self, inchannel, out_channel): + super(LayerNormDense, self).__init__() + self.norm = nn.LayerNorm([inchannel,], epsilon=1e-5) + self.act = nn.Dense(inchannel, out_channel, weight_init=lecun_init(inchannel)).to_float(mstype.float16) + + def construct(self, single_act): + """construct""" + out = self.norm(single_act.astype(mstype.float32)).astype(mstype.float16) + out = self.act(out) + + return out diff --git a/MindSPONGE/applications/research/Grasp/model/evogen.py 
b/MindSPONGE/applications/research/Grasp/model/evogen.py new file mode 100644 index 0000000000000000000000000000000000000000..7056799ebab561d7a28701d0376a79378885bd70 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/model/evogen.py @@ -0,0 +1,285 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""evogen""" +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P +import mindspore.nn.probability.distribution as msd + +from module.evogen_block import EvoformerIteration, LatentBlock, EvoGenFeatProcess, LatentNormal +from model.fold import MegaFold +import numpy as np +from mindsponge.cell.initializer import lecun_init + + +class MsaGen(nn.Cell): + '''MsaGen''' + + def __init__( + self, + config, + ): + super().__init__() + self.config = config.model.embeddings_and_evoformer + self.config_latent = config.model.latent + + self.evoformer_num_block = self.config.evoformer_num_block + self.msa_act_dim = self.config.msa_channel + self.pair_act_dim = self.config.pair_channel + self.num_noise = self.config_latent.num_noise + self.noise_layers = self.config_latent.noise_layers + self.latent_dims = self.config_latent.latent_dim_tuple + self.del_num_bins = self.config.del_num_bins + + evoformer_encoder_blocks = nn.CellList() + evoformer_decoder_blocks = nn.CellList() + for i in range(self.evoformer_num_block): + evoformer_block = EvoformerIteration(config, + msa_act_dim=self.msa_act_dim, + pair_act_dim=self.pair_act_dim, + encoding=True, + ) + evoformer_encoder_blocks.append(evoformer_block) + + evoformer_block = EvoformerIteration(config, + msa_act_dim=self.msa_act_dim, + pair_act_dim=self.pair_act_dim, + encoding=False, + ) + evoformer_decoder_blocks.append(evoformer_block) + self.evoformer_encoder = evoformer_encoder_blocks + self.evoformer_decoder = evoformer_decoder_blocks + + self.latent_normal = LatentNormal() + latent_blocks = nn.CellList() + for i in range(self.num_noise): + lt_block = LatentBlock(config, + msa_dim=self.msa_act_dim, + latent_dim=self.latent_dims[i], + ) + latent_blocks.append(lt_block) + self.latent_block = latent_blocks + + self.num_aa_types = config.global_config.num_aa_types + self.num_msa_types = self.num_aa_types + 1 + self.pair_bins = self.config.num_buckets * 2 + 1 + self.num_del_num_bins = len(self.del_num_bins) + self.del_num_bins = Tensor(self.del_num_bins, mstype.float32) + + self.preprocess_1d = nn.Dense(self.num_aa_types, self.msa_act_dim, weight_init=lecun_init(self.num_aa_types)) + self.preprocess_msa = nn.Dense(self.num_msa_types + 2, self.msa_act_dim, + weight_init=lecun_init(self.num_msa_types)) + self.left_single = nn.Dense(self.num_aa_types, self.pair_act_dim, weight_init=lecun_init(self.num_aa_types)) + self.right_single = nn.Dense(self.num_aa_types, self.pair_act_dim, weight_init=lecun_init(self.num_aa_types)) + 
self.pair_activations = nn.Dense(self.pair_bins, self.pair_act_dim, weight_init=lecun_init(self.pair_bins)) + + np_mask = np.ones(shape=(self.num_msa_types), dtype=np.float32) + np_mask[20], np_mask[22] = 0, 0 + self.reconstruct_mask = Tensor(np_mask, mstype.float32) + + self.reconstruct_head = nn.Dense(self.msa_act_dim, self.num_msa_types, weight_init='zeros', has_bias=True) + + self.reconstruct_head_query_new = nn.Dense(self.msa_act_dim, self.num_msa_types, weight_init='zeros', + has_bias=True) + + self.reconstruct_head_hasdel = nn.Dense(self.msa_act_dim, 1, weight_init='zeros', has_bias=True, + bias_init='ones') + self.reconstruct_head_delnum = nn.Dense(self.msa_act_dim, self.num_del_num_bins, weight_init='zeros', + has_bias=True) + + self.matmul = P.MatMul(transpose_b=True) + self.expand_dims = P.ExpandDims() + + def construct(self, q_raw_feat, msa_raw_feat, pair_raw_feat, msa_mask, pair_mask, context_mask, target_mask, + res_idx=None, random_feat=None): + '''construct''' + mask_tmp = P.Transpose()(msa_mask * context_mask, (1, 0)) + mask_norm = self.matmul(mask_tmp, mask_tmp) + mask_norm = self.expand_dims(mask_norm, -1) + + msa_activations, pair_activations = self._init_feat(q_raw_feat, msa_raw_feat, pair_raw_feat) + msa_act_list = [msa_activations] + pair_act_list = [pair_activations] + for i in range(self.evoformer_num_block): + msa_activations, pair_activations = self.evoformer_encoder[i](msa_activations, pair_activations, \ + msa_mask, pair_mask, context_mask, + mask_norm=mask_norm, res_idx=res_idx) + msa_act_list.append(msa_activations) + pair_act_list.append(pair_activations) + + msa_recon_act = P.Tile()(self.expand_dims(msa_activations[0], 0), (msa_activations.shape[0], 1, 1)) + + kl_all = [] + i_layer = 0 + for i in range(self.num_noise): + layers = self.noise_layers[i] + for _ in range(layers): + msa_recon_act, _ = self.evoformer_decoder[i_layer](msa_recon_act, pair_act_list[-(i_layer + 1)], \ + msa_mask, pair_mask, context_mask, res_idx=res_idx) + i_layer += 1 + + eps = None + if random_feat is not None: + eps = random_feat[i] + eps = eps[:, :, :self.latent_dims[i]] + + latent_block_result = self.latent_block[i](msa_recon_act, msa_act_list[-(i_layer + 1)], msa_mask, + context_mask, target_mask, eps) + msa_recon_act, mu_prior, log_sigma_prior, mu_posterior, log_sigma_posterior = latent_block_result + + mu_posterior[0] = mu_prior[0] * 1. + log_sigma_posterior[0] = log_sigma_prior[0] * 1. 
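+            # Row 0 corresponds to the query sequence: pinning its posterior to
+            # the prior makes its KL contribution below exactly zero, so only
+            # the non-query MSA rows are penalized by the latent loss.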
+ + kl_per_var = self.latent_normal(mu_posterior, log_sigma_posterior, mu_prior, log_sigma_prior) + kl_all.append(kl_per_var.sum(axis=-1)) + + if i == self.num_noise - 1: + for j in range(i_layer, self.evoformer_num_block): + msa_recon_act, _ = self.evoformer_decoder[j](msa_recon_act, pair_act_list[-(j + 1)], \ + msa_mask, pair_mask, context_mask, mask_norm=mask_norm, + res_idx=res_idx) + + q_act = msa_recon_act[0] + q_logits = self.reconstruct_head_query_new(q_act) + q_logits = q_logits.astype(mstype.float32) + 1e9 * P.Reshape()(self.reconstruct_mask - 1., (1, -1)) + q_logits += 1e9 * P.Reshape()(self.reconstruct_mask - 1., (1, -1)) + + recon_logits = self.reconstruct_head(msa_recon_act).astype(mstype.float32) + recon_logits += 1e9 * P.Reshape()(self.reconstruct_mask - 1., (1, 1, -1)) + + hasdel_logits = self.reconstruct_head_hasdel(msa_recon_act).astype(mstype.float32) + delnum_logits = self.reconstruct_head_delnum(msa_recon_act).astype(mstype.float32) + + logits = P.Concat(0)((P.ExpandDims()(q_logits, 0), recon_logits[1:])) + + no_del_prob, mean_delnum = self._compute_del_num(hasdel_logits, delnum_logits) + + return logits, no_del_prob, mean_delnum + + def _init_feat(self, q_raw_feat, msa_raw_feat, pair_raw_feat): + '''init_feat''' + q_feat = self.preprocess_1d(q_raw_feat) + msa_feat = self.preprocess_msa(msa_raw_feat) + msa_activations = self.expand_dims(q_feat, 0) + msa_feat + + pair_activations = self.pair_activations(pair_raw_feat) + left_single = self.left_single(q_raw_feat) + right_single = self.right_single(q_raw_feat) + pair_activations += self.expand_dims(left_single, 1) + self.expand_dims(right_single, 0) + return msa_activations, pair_activations + + def _compute_del_num(self, hasdel_logits, delnum_logits): + '''compute_del_num''' + hasdel_logits = P.Squeeze(-1)(hasdel_logits.astype(mstype.float32)) + no_del_prob = P.Sigmoid()(hasdel_logits) + mean_delnum = P.Softmax(-1)(delnum_logits.astype(mstype.float32)) * P.Reshape()(self.del_num_bins, (1, 1, -1)) + mean_delnum = P.ReduceSum()(mean_delnum, -1) + return no_del_prob, mean_delnum + + +class MegaEvogen(nn.Cell): + '''MegaEvogen''' + + def __init__(self, msa_model_config, model_cfg, mixed_precision): + super().__init__() + self.msa_vae = MsaGen(msa_model_config) + self.feat_process = EvoGenFeatProcess( + config=msa_model_config, + ) + self.megafold = MegaFold(model_cfg, mixed_precision) + + self.softmax_temperature = msa_model_config.train.softmax_temperature + self.use_gumbel_trick = msa_model_config.train.use_gumbel_trick + self.use_dark_knowledge = msa_model_config.train.use_dark_knowledge + self.uniform = msd.Uniform(1e-5, 1. 
- 1e-5, dtype=mstype.float32) + + augmented_msa_depth = min(msa_model_config.train.augmented_msa_depth, msa_model_config.train.max_msa_clusters) + augmented_msa_mask = np.ones([msa_model_config.train.max_msa_clusters]) + augmented_msa_mask[augmented_msa_depth:] *= 0 + self.augmented_msa_mask = Tensor(augmented_msa_mask, mstype.float32) + self.onehot = nn.OneHot(depth=msa_model_config.global_config.num_aa_types + 1) + self.concat = P.Concat(-1) + self.softmax = nn.Softmax() + + def construct(self, target_feat, seq_mask, aatype, residx_atom37_to_atom14, atom37_atom_exists, + residue_index, msa_mask, msa_data, query_data, addition_data, random_data, + random_mask, fake_template_aatype, fake_template_all_atom_masks, + fake_template_all_atom_positions, fake_template_mask, + fake_template_pseudo_beta_mask, fake_template_pseudo_beta, + fake_extra_msa, fake_extra_has_deletion, fake_extra_deletion_value, + fake_extra_msa_mask, prev_pos, prev_msa_first_row, prev_pair): + '''construct''' + msa_mask_new = msa_mask[:, 0].astype(mstype.float32) + context_mask = random_mask[:, 0].astype(mstype.float32) + target_mask = random_mask[:, 1].astype(mstype.float32) + + context_mask_new = context_mask * msa_mask_new + target_mask_new = target_mask * context_mask * msa_mask_new + + random_mask_correct = P.Stack(-1)((context_mask_new, target_mask_new)) + msa_input = self.concat((msa_data, P.ExpandDims()(msa_mask, -1))) + + _, feat_tuple = self.feat_process(query_data, msa_input, addition_data, + random_data, random_mask_correct) + q_raw_feat, msa_raw_feat, pair_raw_feat, msa_mask, pair_mask, context_mask, target_mask, \ + res_idx, random_feat = feat_tuple + msa_logits, no_del_prob, mean_delnum = self.msa_vae(q_raw_feat, msa_raw_feat, + pair_raw_feat, msa_mask, pair_mask, + context_mask, target_mask, res_idx, + random_feat) + + msa_prob = self.softmax(msa_logits * self.softmax_temperature) + + msa_reduce = P.Argmax(axis=-1)(msa_logits) + msa = self.onehot(msa_reduce) + + if self.use_gumbel_trick: + gumbel = self.uniform.sample(msa_logits.shape).astype(msa_logits.dtype) + msa_reduce = P.Argmax(axis=-1)(msa_logits / self.softmax_temperature + gumbel) + msa = self.onehot(msa_reduce) + + if self.use_dark_knowledge: + msa = msa_prob + + pad_zero = P.ZerosLike()(msa)[:, :, :1] + msa_feat_new_generate = self.concat((msa, pad_zero, pad_zero, msa, pad_zero)) + + has_del_prob = 1. - no_del_prob + del_num_feat = has_del_prob * mean_delnum + has_del_prob = P.ExpandDims()(has_del_prob, -1) + del_num_feat = P.ExpandDims()(del_num_feat, -1) + has_del_prob[0] *= 0. + del_num_feat[0] *= 0. + msa_feat_new_reconstruct = self.concat((msa, has_del_prob, del_num_feat, msa, del_num_feat)) + + recon_mask = target_mask_new + recon_mask_new = P.Reshape()(recon_mask, (-1, 1, 1)) + gen_mask_new = P.Reshape()((1. 
- recon_mask), (-1, 1, 1)) + msa_feat_new = recon_mask_new * msa_feat_new_reconstruct + gen_mask_new * msa_feat_new_generate + msa_mask_af2 = self.augmented_msa_mask + msa_mask_new = P.ExpandDims()(msa_mask_af2, 1) * P.ExpandDims()(seq_mask, 0) + + msa_feat_new = msa_feat_new * P.ExpandDims()(msa_mask_new, -1) + + prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits = \ + self.megafold(target_feat, msa_feat_new, msa_mask_new, seq_mask, aatype, + fake_template_aatype, fake_template_all_atom_masks, fake_template_all_atom_positions, + fake_template_mask, fake_template_pseudo_beta_mask, fake_template_pseudo_beta, fake_extra_msa, + fake_extra_has_deletion, fake_extra_deletion_value, fake_extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, residue_index, + prev_pos, prev_msa_first_row, prev_pair) + result = prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits + return result diff --git a/MindSPONGE/applications/research/Grasp/model/fold.py b/MindSPONGE/applications/research/Grasp/model/fold.py new file mode 100644 index 0000000000000000000000000000000000000000..a21098f52cd26b9a5777b2c15fc4c199baaf0178 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/model/fold.py @@ -0,0 +1,660 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""model""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore import Parameter +from mindspore import ops +import mindsponge.common.residue_constants as residue_constants +from mindsponge1.common.utils import dgram_from_positions, pseudo_beta_fn, atom37_to_torsion_angles +from mindsponge1.data.data_transform import get_chi_atom_pos_indices +from mindsponge1.cell.initializer import lecun_init +from module.template_embedding import MultimerTemplateEmbedding #TemplateEmbedding +from module.evoformer import MultimerEvoformer #Evoformer +# from module.structure import StructureModule +from module.structure_multimer import MultimerStructureModule +from module.head import DistogramHead, ExperimentallyResolvedHead, MaskedMsaHead, \ + PredictedLDDTHead, PredictedAlignedErrorHead + +from common.utils import compute_chi_angles, ComputeChiAngles +from scipy.special import softmax +from restraint_sample import BINS +# from mindsponge1.cell.dense import ProcessSBR +from mindspore.communication import get_rank + +from typing import Dict, Optional, Tuple +import numpy as np +import scipy.special + +def caculate_constant_array(seq_length): + '''constant array''' + chi_atom_indices = np.array(get_chi_atom_pos_indices()).astype(np.int32) + chi_angles_mask = list(residue_constants.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = np.array(chi_angles_mask).astype(np.float32) + mirror_psi_mask = np.float32(np.asarray([1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]) + chi_pi_periodic = np.float32(np.array(residue_constants.chi_pi_periodic)) + + indices0 = np.arange(4).reshape((-1, 1, 1, 1, 1)).astype("int32") # 4 batch + indices0 = indices0.repeat(seq_length, axis=1) # seq_length sequence length + indices0 = indices0.repeat(4, axis=2) # 4 chis + indices0 = indices0.repeat(4, axis=3) # 4 atoms + + indices1 = np.arange(seq_length).reshape((1, -1, 1, 1, 1)).astype("int32") + indices1 = indices1.repeat(4, axis=0) + indices1 = indices1.repeat(4, axis=2) + indices1 = indices1.repeat(4, axis=3) + + constant_array = [chi_atom_indices, chi_angles_mask, mirror_psi_mask, chi_pi_periodic, indices0, indices1] + constant_array = [Tensor(val) for val in constant_array] + return constant_array + + +def compute_confidence(predicted_lddt_logits, return_lddt=False): + """compute confidence""" + + num_bins = predicted_lddt_logits.shape[-1] + bin_width = 1 / num_bins + start_n = bin_width / 2 + plddt = compute_plddt(predicted_lddt_logits, start_n, bin_width) + confidence = np.mean(plddt) + if return_lddt: + return confidence, plddt + + return confidence + + +def compute_plddt(logits, start_n, bin_width): + """Computes per-residue pLDDT from logits. + + Args: + logits: [num_res, num_bins] output from the PredictedLDDTHead. + + Returns: + plddt: [num_res] per-residue pLDDT. + """ + bin_centers = np.arange(start=start_n, stop=1.0, step=bin_width) + probs = softmax(logits, axis=-1) + predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1) + return predicted_lddt_ca * 100 + + +def _calculate_bin_centers(breaks: np.ndarray): + """Gets the bin centers from the bin edges. + + Args: + breaks: [num_bins - 1] the error bin edges. + + Returns: + bin_centers: [num_bins] the error bin centers. 
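+
+  Example:
+    For breaks [4.0, 8.0, 12.0] the step is 4.0 and the centers are the
+    breaks shifted by half a step, [6.0, 10.0, 14.0], plus a catch-all
+    bin at 14.0 + 4.0 = 18.0, giving [6.0, 10.0, 14.0, 18.0].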
+ """ + step = (breaks[1] - breaks[0]) + + # Add half-step to get the center + bin_centers = breaks + step / 2 + # Add a catch-all bin at the end. + bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]], + axis=0) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: np.ndarray, + aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Calculates expected aligned distance errors for every pair of residues. + + Args: + alignment_confidence_breaks: [num_bins - 1] the error bin edges. + aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted + probs for each error bin, for each pair of residues. + + Returns: + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + + # Tuple of expected aligned distance error and max possible error. + return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1), + np.asarray(bin_centers[-1])) + + +def compute_predicted_aligned_error( + logits: np.ndarray, + breaks: np.ndarray) -> Dict[str, np.ndarray]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins - 1] the error bin edges. + + Returns: + aligned_confidence_probs: [num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + aligned_confidence_probs = scipy.special.softmax( + logits, + axis=-1) + predicted_aligned_error, max_predicted_aligned_error = ( + _calculate_expected_aligned_error( + alignment_confidence_breaks=breaks, + aligned_distance_error_probs=aligned_confidence_probs)) + return { + 'aligned_confidence_probs': aligned_confidence_probs, + 'predicted_aligned_error': predicted_aligned_error, + 'max_predicted_aligned_error': max_predicted_aligned_error, + } + + +def predicted_tm_score( + logits: np.ndarray, + breaks: np.ndarray, + residue_weights: Optional[np.ndarray] = None, + asym_id: Optional[np.ndarray] = None, + interface: bool = False) -> np.ndarray: + """Computes predicted TM alignment or predicted interface TM alignment score. + + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins] the error bins. + residue_weights: [num_res] the per residue weights to use for the + expectation. + asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for + ipTM calculation, i.e. when interface=True. + interface: If True, interface predicted TM score is computed. + + Returns: + ptm_score: The predicted TM alignment or the predicted iTM score. + """ + + # residue_weights has to be in [0, 1], but can be floating-point, i.e. the + # exp. resolved head's probability. + if residue_weights is None: + residue_weights = np.ones(logits.shape[0]) + + bin_centers = _calculate_bin_centers(breaks) + + num_res = int(np.sum(residue_weights)) + # Clip num_res to avoid negative/undefined d0. + clipped_num_res = max(num_res, 19) + + # Compute d_0(num_res) as defined by TM-score, eqn. 
(5) in Yang & Skolnick + # "Scoring function for automated assessment of protein structure template + # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf + + d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8 + + # Convert logits to probs. + probs = scipy.special.softmax(logits, axis=-1) + + # TM-Score term for every bin. + tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0)) + # E_distances tm(distance). + predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1) + + pair_mask = np.ones(shape=(num_res, num_res), dtype=bool) + if interface: + pair_mask *= asym_id[:, None] != asym_id[None, :] + + predicted_tm_term *= pair_mask + + pair_residue_weights = pair_mask * ( + residue_weights[None, :] * residue_weights[:, None]) + normed_residue_mask = pair_residue_weights / (1e-8 + np.sum( + pair_residue_weights, axis=-1, keepdims=True)) + per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1) + return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()]) + + +def compute_ranking_score(logits, breaks, asym_id): + # print(logits.shape, breaks.shape, asym_id.shape) + iptm = predicted_tm_score(logits, breaks, asym_id=asym_id, interface=True) + ptm = predicted_tm_score(logits, breaks) + return 0.8*iptm + 0.2*ptm + + +class MegaFold(nn.Cell): + """MegaFold""" + + def __init__(self, config, mixed_precision, device_num): + super(MegaFold, self).__init__() + self.dump = ops.TensorDump() + self.cfg = config + + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + + self.is_training = self.cfg.is_training + self.recycle_pos = self.cfg.recycle_pos + self.recycle_features = self.cfg.recycle_features + self.max_relative_feature = self.cfg.max_relative_feature + self.use_chain_relative = self.cfg.multimer.embeddings_and_evoformer.use_chain_relative + self.max_relative_chain = self.cfg.multimer.embeddings_and_evoformer.max_relative_chain + + self.num_bins = self.cfg.prev_pos.num_bins + self.min_bin = self.cfg.prev_pos.min_bin + self.max_bin = self.cfg.prev_pos.max_bin + self.template_enabled = self.cfg.template.enabled + self.extra_msa_stack_num = self.cfg.evoformer.extra_msa_stack_num + self.msa_stack_num = self.cfg.evoformer.msa_stack_num + self.chi_atom_indices, self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, \ + self.indices0, self.indices1 = caculate_constant_array(self.cfg.seq_length) + + # self.contact_one_hot = nn.OneHot(depth=2, axis=-1) + self.sbr_act_dim = 128 + self.sbr_act1 = nn.Dense(len(BINS)+1, self.sbr_act_dim, weight_init=lecun_init(len(BINS)+1), activation='relu') + self.sbr_act2 = nn.Dense(self.sbr_act_dim, self.sbr_act_dim, weight_init=lecun_init(self.sbr_act_dim)) + # self.sbr_gate = nn.Dense(self.sbr_act_dim+self.cfg.pair_channel, self.sbr_act_dim, weight_init='zeros', bias_init='ones') + self.sigmoid = nn.Sigmoid() + # self.preprocess_contact = nn.Dense(1, 128, lecun_init(15)).to_float(mstype.float16) + # self.process_sbr = ProcessSBR(len(BINS)+1, 32, gate=True, pair_input_dim=self.cfg.pair_channel) + # self.process_sbr = ProcessSBR(len(BINS)+1, 32) + + # print("debug self.cfg.common.target_feat_dim", self.cfg.common.target_feat_dim) + self.preprocess_1d = nn.Dense(self.cfg.common.target_feat_dim, self.cfg.msa_channel, + weight_init=lecun_init(self.cfg.common.target_feat_dim)) + self.preprocess_msa = nn.Dense(self.cfg.common.msa_feat_dim, self.cfg.msa_channel, + weight_init=lecun_init(self.cfg.common.msa_feat_dim)) + # self.preprocess_msa + self.left_single = 
nn.Dense(self.cfg.common.target_feat_dim, self.cfg.pair_channel, + weight_init=lecun_init(self.cfg.common.target_feat_dim)) + self.right_single = nn.Dense(self.cfg.common.target_feat_dim, self.cfg.pair_channel, + weight_init=lecun_init(self.cfg.common.target_feat_dim)) + self.prev_pos_linear = nn.Dense(self.cfg.common.dgram_dim, self.cfg.pair_channel, + weight_init=lecun_init(self.cfg.common.dgram_dim)) + + self.extra_msa_one_hot = nn.OneHot(depth=23, axis=-1) + # self.extra_msa_one_hot.onehot.shard(((1,2,1),)) + # self.extra_msa_one_hot = P.OneHot(-1).shard(((1,2,1),(),(),())) + self.template_aatype_one_hot = nn.OneHot(depth=22, axis=-1) + self.prev_msa_first_row_norm = nn.LayerNorm([256,], epsilon=1e-5) + self.prev_pair_norm = nn.LayerNorm([128,], epsilon=1e-5) + # self.prev_pair_norm.layer_norm.shard(((1, device_num, 1), (1,), (1,))) + if self.use_chain_relative: + self.rel_pos_one_hot = nn.OneHot(depth=self.cfg.max_relative_feature * 2 + 2, axis=-1) # 32 * 2 + 2 = 66 + self.rel_chain_one_hot = nn.OneHot(depth=self.max_relative_chain * 2 + 2, axis=-1) # 2 * 2 + 2 = 6 + self.position_activations = nn.Dense(self.cfg.multimer.pair_in_dim, self.cfg.pair_channel, #73 + weight_init=lecun_init(self.cfg.multimer.pair_in_dim)) + self.interface_activations = nn.Dense(2, self.cfg.pair_channel, #2 + weight_init='zeros', + has_bias=False) + else: + self.one_hot = nn.OneHot(depth=self.cfg.max_relative_feature * 2 + 1, axis=-1) # 65 + self.position_activations = nn.Dense(self.cfg.common.pair_in_dim, self.cfg.pair_channel, + weight_init=lecun_init(self.cfg.common.pair_in_dim)) + self.extra_msa_activations = nn.Dense(25, self.cfg.extra_msa_channel, weight_init=lecun_init(25)) + self.template_embedding = MultimerTemplateEmbedding(self.cfg, device_num, mixed_precision) + + self.matmul_trans_b = P.MatMul(transpose_b=True) + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.template_single_embedding = nn.Dense(34, self.cfg.msa_channel, + weight_init= + lecun_init(34, initializer_name='relu')) + self.template_projection = nn.Dense(self.cfg.msa_channel, self.cfg.msa_channel, + weight_init=lecun_init(self.cfg.msa_channel, + initializer_name='relu')) + self.relu = nn.ReLU() + self.single_activations = nn.Dense(self.cfg.msa_channel, self.cfg.seq_channel, + weight_init=lecun_init(self.cfg.msa_channel)) + extra_msa_stack = nn.CellList() + for _ in range(self.extra_msa_stack_num): + extra_msa_block = MultimerEvoformer(self.cfg, + msa_act_dim=64, + pair_act_dim=128, + is_extra_msa=True, + batch_size=None, + device_num=device_num) + extra_msa_stack.append(extra_msa_block) + self.extra_msa_stack = extra_msa_stack + self.aligned_error = PredictedAlignedErrorHead(self.cfg.heads.predicted_aligned_error, + self.cfg.pair_channel) + if self.is_training: + msa_stack = nn.CellList() + for _ in range(self.msa_stack_num): + msa_block = MultimerEvoformer(self.cfg, + msa_act_dim=256, + pair_act_dim=128, + is_extra_msa=False, + batch_size=None, + device_num=device_num) + msa_stack.append(msa_block) + self.msa_stack = msa_stack + self.module_distogram = DistogramHead(self.cfg.heads.distogram, + self.cfg.pair_channel) + self.module_exp_resolved = ExperimentallyResolvedHead(self.cfg.seq_channel) + self.module_mask = MaskedMsaHead(self.cfg.heads.masked_msa, + self.cfg.msa_channel) + else: + # print("debug", self.cfg, self.msa_stack_num) + self.msa_stack = MultimerEvoformer(self.cfg, + msa_act_dim=256, + pair_act_dim=128, + is_extra_msa=False, + batch_size=self.msa_stack_num, + device_num=device_num) + 
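+            # Inference path: one MultimerEvoformer cell is built with
+            # batch_size=self.msa_stack_num, so the weights of all Evoformer blocks
+            # live in a single batched cell. The counter below selects the active
+            # block on each pass of the while-loop in construct(), instead of
+            # unrolling a CellList as in the training branch above.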
self.idx_evoformer_block = Parameter(Tensor(0, mstype.int32), requires_grad=False) + self.evoformer_num_block_eval = Tensor(self.msa_stack_num, mstype.int32) + + self.structure_module = MultimerStructureModule(self.cfg, + self.cfg.seq_channel, + self.cfg.pair_channel, + device_num) + # raw notion + # self.structure_module = StructureModule(self.cfg, + # self.cfg.seq_channel, + # self.cfg.pair_channel) + + self.module_lddt = PredictedLDDTHead(self.cfg.heads.predicted_lddt, + self.cfg.seq_channel) + self.add_2 = P.Add().shard(((1, device_num, 1), (1, device_num, 1))) + self.concat_0_2 = P.Concat(0).shard(((1, device_num, 1), (1, device_num, 1))) + self.concat_e_3 = P.Concat(-1).shard(((1, device_num, 1), (1, device_num, 1), (1, device_num, 1))) + self.cast3 = P.Cast().shard(((1, device_num, 1),)) + self.cast2 = P.Cast().shard(((1, device_num),)) + self.expand2 = P.ExpandDims().shard(((1, device_num),)) + self.concat_e_4 = P.Concat(-1).shard(((1, device_num, 1),(1, device_num, 1), (1, device_num, 1), (1, device_num, 1))) + self.concat_0_2_2 = P.Concat(0).shard(((1, device_num), (1, device_num))) + self.compute_chi_angles = ComputeChiAngles(device_num) + self.squeeze2 = P.Squeeze().shard(((1, device_num, 1),)) + self.slice_ops2 = P.Slice().shard(((1, device_num, 1),)) + self.allgather3 = P.StridedSlice().shard(((1, 1, 1),)) + self.allgather2 = P.StridedSlice().shard(((1, 1),)) + + + def _relative_encoding(self, residue_index, asym_id, sym_id, entity_id, interface_mask): + """Add relative position encoding""" + rel_feats = [] + asym_id_same = mnp.equal(P.ExpandDims()(asym_id, 1), P.ExpandDims()(asym_id, 0)).astype(mstype.int32) # seq_len * seq_len + offset = P.ExpandDims()(residue_index, 1) - P.ExpandDims()(residue_index, 0) # seq_len * seq_len + clipped_offset = mnp.clip( + offset + self.max_relative_feature, xmin=0, xmax= 2 * self.max_relative_feature) + interface_feat = None + if self.use_chain_relative: + final_offset = mnp.where(asym_id_same, clipped_offset, + (2 * self.max_relative_feature + 1) * + mnp.ones_like(clipped_offset)) + rel_pos = self.rel_pos_one_hot(final_offset) # seq_len * seq_len * 66 + rel_feats.append(rel_pos) + # entity_id_same = mnp.equal(entity_id[:, None], entity_id[None, :]) # seq_len * seq_len * 1 + entity_id_same = mnp.equal(P.ExpandDims()(entity_id, 1), P.ExpandDims()(entity_id, 0)).astype(mstype.int32) # seq_len * seq_len * 1 + rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None]) + rel_sym_id = P.ExpandDims()(sym_id, 1) - P.ExpandDims()(sym_id, 0) + max_rel_chain = self.max_relative_chain + clipped_rel_chain = mnp.clip( + rel_sym_id + max_rel_chain, xmin=0, xmax=2 * max_rel_chain) + # entity_id_same = entity_id_same.astype(mstype.int32) + final_rel_chain = mnp.where(entity_id_same, clipped_rel_chain, + (2 * max_rel_chain + 1) * + mnp.ones_like(clipped_rel_chain)) + rel_chain = self.rel_chain_one_hot(final_rel_chain.astype(mstype.int32)) # seq_len * seq_len * 6 + rel_feats.append(rel_chain) + interface_feat = mnp.concatenate([mnp.tile(interface_mask[:, None, None], (1, len(interface_mask), 1)), mnp.tile(interface_mask[None, :, None], (len(interface_mask), 1, 1))], axis=-1) + else: + rel_pos = self.one_hot(clipped_offset) + rel_feats.append(rel_pos) + rel_feat = mnp.concatenate(rel_feats, axis=-1) # seq_len * seq_len * 73 for multimer + return self.position_activations(rel_feat)+self.interface_activations(interface_feat)#, rel_feat, interface_feat + + def construct(self, aatype, residue_index, template_aatype, template_all_atom_masks, 
template_all_atom_positions, + asym_id, sym_id, entity_id, seq_mask, msa_mask, target_feat, msa_feat, + extra_msa, extra_msa_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, + sbr, sbr_mask, interface_mask, prev_pos, prev_msa_first_row, prev_pair): + # def construct(self, extra_msa_activations, pair_activations, extra_msa_mask, extra_msa_norm, mask_2d, rank): + # def construct(self, msa_feat): + """construct""" + # # print("debug target_feat", target_feat, type(target_feat[0][0]), target_feat[0][0].dtype) + preprocess_1d = self.preprocess_1d(target_feat) # raw target_feat 256 21 1, 128 256 (1,2,1) + # self.dump(f"msa_feat_23_rank{rank}", msa_feat) + preprocess_msa = self.preprocess_msa(msa_feat) # raw (508 128 49) (49 256) + # self.dump(f"preprocess_msa_23_rank{rank}", preprocess_msa) + + # msa_activations = mnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa + msa_activations = self.add_2(mnp.expand_dims(preprocess_1d, axis=0), preprocess_msa) + # print("debug megafold msa_activations", preprocess_msa) + left_single = self.left_single(target_feat) + right_single = self.right_single(target_feat) + # # print("debug megafold left_single", left_single) + # # print("debug megafold right_single", right_single) + + pair_activations = P.ExpandDims()(left_single, 1) + P.ExpandDims()(right_single, 0) + + # # print("pair_activations 410", pair_activations) + mask_2d = P.ExpandDims()(seq_mask, 1) * P.ExpandDims()(seq_mask, 0) + if self.recycle_pos: + prev_pseudo_beta, _ = pseudo_beta_fn(aatype, prev_pos, atom37_atom_exists) + dgram = dgram_from_positions(prev_pseudo_beta, self.num_bins, self.min_bin, self.max_bin, self._type) + pair_activations += self.prev_pos_linear(dgram) + # print("pair_activations 418", pair_activations) + # self.dump(f"pair_activations_437_{rank}_copy", pair_activations) + if self.recycle_features: + # print("debug prev_msa_first_row1", prev_msa_first_row) + prev_msa_first_row = self.prev_msa_first_row_norm(prev_msa_first_row) + # print("debug prev_msa_first_row2", prev_msa_first_row) + # print("debug mnp.expand_dims(prev_msa_first_row + msa_activations[0, ...], 0): ", mnp.expand_dims(prev_msa_first_row + msa_activations[0, ...], 0)) + # msa_activations = mnp.concatenate( + # (mnp.expand_dims(prev_msa_first_row + msa_activations[0, ...], 0), msa_activations[1:, ...]), 0) + # msa_activations = self.concat_0_2((mnp.expand_dims(prev_msa_first_row + msa_activations[0, ...], 0), + # P.Cast()(msa_activations[1:, ...], mstype.float32))) + msa_activations = self.concat_0_2((mnp.expand_dims(prev_msa_first_row + msa_activations[0, ...], 0), + msa_activations[1:, ...])) + pair_activations += self.prev_pair_norm(prev_pair) + # print("msa_activations: ", msa_activations) + # print("pair_activations: ", pair_activations) + + if self.max_relative_feature: + pair_activations += self._relative_encoding(residue_index, asym_id, sym_id, entity_id, interface_mask) + # print("pair_activations 429", pair_activations) + # return pair_activations + # template_pair_representation = 0 + if self.template_enabled: + multichain_mask = mnp.equal(P.ExpandDims()(asym_id, 1), P.ExpandDims()(asym_id, 0)) + # print("debug template_embedding", "template_aatype", template_aatype, "template_all_atom_masks", template_all_atom_masks, + # "template_all_atom_positions", template_all_atom_positions, "mask_2d", mask_2d, "multichain_mask", multichain_mask) + template_pair_representation = self.template_embedding(pair_activations, template_aatype, + template_all_atom_masks, + 
template_all_atom_positions, mask_2d, + multichain_mask) + # print("pair_activations 438", pair_activations) + # print("template_pair_representation, self.template_embedding", template_pair_representation) + pair_activations += template_pair_representation + # print("pair_activations", pair_activations) + msa_1hot = self.extra_msa_one_hot(extra_msa) + # msa_1hot = self.extra_msa_one_hot(extra_msa, 23, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)) + extra_has_deletion = self.cast2(extra_msa_deletion_value > 0, extra_msa_deletion_value.dtype) + # extra_msa_feat = mnp.concatenate((msa_1hot, extra_has_deletion[..., None], extra_msa_deletion_value[..., None]), + # axis=-1) + extra_msa_feat = self.concat_e_3((msa_1hot, self.cast3(self.expand2(extra_has_deletion, -1), mstype.float32), self.cast3(self.expand2(extra_msa_deletion_value, -1), mstype.float32))) + # print("extra_msa_feat: ", extra_msa_feat) + extra_msa_activations = self.extra_msa_activations(extra_msa_feat) + # print("extra_msa_activations_raw: ", extra_msa_activations) + extra_msa_norm = P.ExpandDims()(P.MatMul(transpose_a=True)(extra_msa_mask, extra_msa_mask), -1) + # # print("extra_msa_norm: ", extra_msa_norm) + # # print(extra_msa_stack_num) + # self.dump(f"extra_msa_activations_{rank}", extra_msa_activations) + # self.dump(f"pair_activations_{rank}", pair_activations) + # self.dump(f"extra_msa_mask_{rank}", extra_msa_mask) + # self.dump(f"extra_msa_norm_{rank}", extra_msa_norm) + # self.dump(f"mask_2d_{rank}", mask_2d) + # return self.extra_msa_stack_num, extra_msa_activations, pair_activations, extra_msa_mask, extra_msa_norm, mask_2d + for i in range(self.extra_msa_stack_num): + # print("pair_activations 449, round", i, pair_activations) + extra_msa_activations, pair_activations = \ + self.extra_msa_stack[i](extra_msa_activations, pair_activations, extra_msa_mask, extra_msa_norm, + mask_2d) + # print("extra_msa_activations: ", extra_msa_activations) + # print("pair_activations: ", pair_activations) + # return pair_activations + if self.template_enabled: + aatype_one_hot = self.template_aatype_one_hot(template_aatype) + chi_angles, chi_mask = self.compute_chi_angles(template_aatype, + template_all_atom_positions, + template_all_atom_masks, + self.chi_atom_indices, + self.chi_angles_mask, + self.indices0, + self.indices1) + + + template_features = self.concat_e_4([aatype_one_hot, + mnp.sin(chi_angles) * chi_mask, + mnp.cos(chi_angles) * chi_mask, + chi_mask]) + # template_mask = chi_mask[:, :, 0] + template_mask = self.squeeze2(self.slice_ops2(chi_mask, (0,0,0), (chi_mask.shape[0], chi_mask.shape[1], 1))) + template_activations = self.template_single_embedding(template_features) + template_activations = self.relu(template_activations) + template_activations = self.template_projection(template_activations) + # msa_activations = mnp.concatenate([msa_activations, template_activations], axis=0) + # print("msa_activations:", msa_activations) + # msa_activations = self.concat_0_2([msa_activations, self.cast3(template_activations, mstype.float32)]) + msa_activations = self.concat_0_2([msa_activations, template_activations]) + + # print("msa_mask's type: ", msa_mask) + # print("template_mask's type: ", template_mask) + msa_mask = self.concat_0_2_2([self.cast2(msa_mask, mstype.float32), template_mask]) + + # print("msa_mask: ", msa_mask) + # print("msa_activations:", msa_activations) + msa_mask_norm = P.ExpandDims()(P.MatMul(transpose_a=True)(msa_mask, msa_mask), -1) + # print("msa_mask_norm: ", msa_mask_norm) + + # # raw 
notion + # # contact info + # # contact_info_input = contact_mask_input.astype(mstype.float16) + # # contact_feature = contact_info_input[..., None] * 10.0 # increase signal + # # contact_act = self.preprocess_contact(contact_feature) + # # pair_activations += contact_act + + sbr_act = self.sbr_act1(sbr*100) + sbr_act = self.sbr_act2(sbr_act) + + # # raw notion + # # sbr_act = self.sbr_act2(sbr_act) * self.sigmoid(self.sbr_gate(P.Concat(-1)((pair_activations, sbr_act)))) + # # sbr_act = self.process_sbr(sbr, sbr_mask) + # # print("debug msa_activations 477", msa_activations) + + # print("self.is_training:", self.is_training) + # print("msa_activations", msa_activations) + # print("pair_activations", pair_activations) + # print("msa_mask", msa_mask) + # print("msa_mask_norm", msa_mask_norm) + # print("mask_2d", mask_2d) + # print("sbr_act", sbr_act) + # print("sbr_mask", sbr_mask) + # print("interface_mask", interface_mask) + # print("self.idx_evoformer_block", self.idx_evoformer_block) + + # if self.is_training: + # for i in range(self.msa_stack_num): + # msa_activations, pair_activations = self.msa_stack[i](msa_activations, pair_activations, msa_mask, + # msa_mask_norm, mask_2d, sbr_act, sbr_mask, interface_mask) + # else: + self.idx_evoformer_block = self.idx_evoformer_block * 0 + while self.idx_evoformer_block < self.evoformer_num_block_eval: + msa_activations, pair_activations = self.msa_stack(msa_activations, + pair_activations, + msa_mask, + msa_mask_norm, + mask_2d, + sbr_act, + sbr_mask, + interface_mask, + self.idx_evoformer_block) + self.idx_evoformer_block += 1 + # print("msa_activations: ", msa_activations) + # print("pair_activations: ", pair_activations) + + # print("debug msa_activations 496", msa_activations) + # print("msa_activations:", msa_activations) + single_activations = self.single_activations(msa_activations[0]) + # print("single_activations:", single_activations) + num_sequences = msa_feat.shape[0] + # print("num_sequences:", num_sequences) + msa = msa_activations[:num_sequences, :, :] + # print("msa:", msa) + # msa_first_row = msa_activations[0] + msa_first_row = self.squeeze2( + self.slice_ops2(msa_activations, (0,0,0), (1, msa_activations.shape[1], msa_activations.shape[2]))) + # print("msa_first_row:", msa_first_row) + + # print("debug megafold single_activations", single_activations) + final_atom_positions, _, rp_structure_module, _, _, \ + _, _, _, _, _ = \ + self.structure_module(single_activations, + pair_activations, + seq_mask, + aatype, + sbr_act, + sbr_mask, + interface_mask, + residx_atom37_to_atom14, + atom37_atom_exists) + predicted_lddt_logits = self.module_lddt(rp_structure_module) + aligned_error_logits, aligned_error_breaks = self.aligned_error(pair_activations) + # if self.is_training and self.train_backward: + # predicted_lddt_logits = self.module_lddt(rp_structure_module) + # dist_logits, bin_edges = self.module_distogram(pair_activations) + # experimentally_logits = self.module_exp_resolved(single_activations) + # masked_logits = self.module_mask(msa) + # return dist_logits, bin_edges, experimentally_logits, masked_logits, aligned_error_logits, \ + # aligned_error_breaks, atom14_pred_positions, final_affines, angles_sin_cos_new, \ + # predicted_lddt_logits, structure_traj, sidechain_frames, sidechain_atom_pos, \ + # um_angles_sin_cos_new, final_atom_positions + final_atom_positions = P.Cast()(final_atom_positions, self._type) + # print("final_atom_positions: ", final_atom_positions) + prev_pos = final_atom_positions + prev_msa_first_row = 
msa_first_row + prev_pair = pair_activations + # if self.is_training: + # return prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits + # print("debug see confg diff", self.cfg) + # self.dump(f"256prev_pos_{rank}", prev_pos) + # self.dump(f"256prev_msa_first_row_{rank}", prev_msa_first_row) + # self.dump(f"256prev_pair_{rank}", prev_pair) + # self.dump(f"256predicted_lddt_logits_{rank}", predicted_lddt_logits) + # self.dump(f"256aligned_error_logits_{rank}", aligned_error_logits) + + prev_pos = self.allgather3(prev_pos, (0, 0, 0), (self.cfg.seq_length, 37, 3), (1, 1, 1)) + # prev_msa_first_row = self.allgather2(prev_msa_first_row, (0, 0), (self.cfg.seq_length, 256), (1, 1)) + # prev_pair = self.allgather3(prev_pair, (0, 0, 0), (self.cfg.seq_length, self.cfg.seq_length, 128), (1, 1, 1)) + predicted_lddt_logits = self.allgather2(predicted_lddt_logits, (0, 0), (self.cfg.seq_length, 50), (1, 1)) + aligned_error_logits = self.allgather3(aligned_error_logits, (0, 0, 0), (self.cfg.seq_length, self.cfg.seq_length, 64), (1, 1, 1)) + + # prev_msa_first_row = self.allgather(prev_msa_first_row) + # prev_pair = self.allgather(prev_pair) + # predicted_lddt_logits = self.allgather(predicted_lddt_logits) + # aligned_error_logits = self.allgather(aligned_error_logits) + # aligned_error_breaks = self.allgather(aligned_error_breaks) + # print("prev_pos: ", prev_pos) + # print("prev_msa_first_row: ", prev_msa_first_row) + # print("prev_pair: ", prev_pair) + # print("predicted_lddt_logits: ", predicted_lddt_logits) + # print("aligned_error_logits: ", aligned_error_logits) + # print("aligned_error_breaks: ", aligned_error_breaks) + + return prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits, aligned_error_logits, aligned_error_breaks \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/module/evoformer.py b/MindSPONGE/applications/research/Grasp/module/evoformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0652f548dcab019e5d91cea77d0719edc4f82a --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/module/evoformer.py @@ -0,0 +1,296 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Evoformer""" + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindsponge1.cell import MSARowAttentionWithPairBias, Transition, OuterProductMean, \ + TriangleAttention, TriangleMultiplication, \ + MSAColumnGlobalAttention, MSAColumnAttention +# from mindspore import lazy_inline + + +class Evoformer(nn.Cell): + '''evoformer''' + # @lazy_inline + def __init__(self, config, msa_act_dim, pair_act_dim, is_extra_msa, batch_size): + super(Evoformer, self).__init__() + if is_extra_msa: + self.slice_cfg = config.slice.extra_msa_stack + else: + self.slice_cfg = config.slice.msa_stack + self.config = config + + self.msa_act = MsaAct(self.config, + self.slice_cfg, + msa_act_dim, + pair_act_dim, + is_extra_msa, + batch_size) + self.pair_act = PairAct(self.config, + self.slice_cfg, + msa_act_dim, + pair_act_dim, + batch_size) + + if config.is_training: + self.pair_act.recompute() + + def construct(self, msa_act, pair_act, msa_mask, extra_msa_norm, pair_mask, index=None): + '''construct''' + msa_act = self.msa_act(msa_act, pair_act, msa_mask, index) + pair_act = self.pair_act(msa_act, pair_act, msa_mask, extra_msa_norm, pair_mask, index) + return msa_act, pair_act + + +class MsaAct(nn.Cell): + """MsaAct""" + + def __init__(self, config, slice_cfg, msa_act_dim, pair_act_dim, is_extra_msa, batch_size): + super(MsaAct, self).__init__() + + self.slice_cfg = slice_cfg + self.config = config.evoformer + + self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias( + self.config.msa_row_attention_with_pair_bias.num_head, + msa_act_dim, + self.config.msa_row_attention_with_pair_bias.gating, + msa_act_dim, + pair_act_dim, + batch_size, + self.slice_cfg.msa_row_attention_with_pair_bias) + self.msa_transition = Transition(self.config.msa_transition.num_intermediate_factor, + msa_act_dim, + batch_size, + self.slice_cfg.msa_transition) + if is_extra_msa: + self.attn_mod = MSAColumnGlobalAttention(self.config.msa_column_attention.num_head, + self.config.msa_column_attention.gating, + msa_act_dim, + batch_size, + self.slice_cfg.msa_column_global_attention) + else: + self.attn_mod = MSAColumnAttention(self.config.msa_column_attention.num_head, + msa_act_dim, + self.config.msa_column_attention.gating, + msa_act_dim, + batch_size, + self.slice_cfg.msa_column_attention) + + if config.is_training: + self.msa_row_attention_with_pair_bias.recompute() + self.attn_mod.recompute() + self.msa_transition.recompute() + + def construct(self, msa_act, pair_act, msa_mask, index=None): + '''construct''' + msa_act = P.Add()(msa_act, self.msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act, index)) + msa_act = P.Add()(msa_act, self.attn_mod(msa_act, msa_mask, index)) + msa_act = P.Add()(msa_act, self.msa_transition(msa_act, index)) + return msa_act + + +class PairAct(nn.Cell): + """PairAct""" + + def __init__(self, config, slice_cfg, msa_act_dim, pair_act_dim, batch_size): + super(PairAct, self).__init__() + self.slice_cfg = slice_cfg + self.config = config.evoformer + + self.outer_product_mean = OuterProductMean(self.config.outer_product_mean.num_outer_channel, + msa_act_dim, + pair_act_dim, + batch_size, + self.slice_cfg.outer_product_mean) + + self.triangle_attention_starting_node = TriangleAttention( + self.config.triangle_attention_starting_node.orientation, + self.config.triangle_attention_starting_node.num_head, + pair_act_dim, + self.config.triangle_attention_starting_node.gating, + pair_act_dim, + 
batch_size, + self.slice_cfg.triangle_attention_starting_node) + + self.triangle_attention_ending_node = TriangleAttention(self.config.triangle_attention_ending_node.orientation, + self.config.triangle_attention_ending_node.num_head, + pair_act_dim, + self.config.triangle_attention_ending_node.gating, + pair_act_dim, + batch_size, + self.slice_cfg.triangle_attention_ending_node) + + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, + pair_act_dim, + batch_size, + self.slice_cfg.pair_transition) + + self.triangle_multiplication_outgoing = TriangleMultiplication( + self.config.triangle_multiplication_outgoing.num_intermediate_channel, + self.config.triangle_multiplication_outgoing.equation, + layer_norm_dim=pair_act_dim, + batch_size=batch_size) + + self.triangle_multiplication_incoming = TriangleMultiplication( + self.config.triangle_multiplication_incoming.num_intermediate_channel, + self.config.triangle_multiplication_incoming.equation, + layer_norm_dim=pair_act_dim, + batch_size=batch_size) + + def construct(self, msa_act, pair_act, msa_mask, extra_msa_norm, pair_mask, index=None): + '''construct''' + pair_act = P.Add()(pair_act, self.outer_product_mean(msa_act, msa_mask, extra_msa_norm, index)) + pair_act = P.Add()(pair_act, self.triangle_multiplication_outgoing(pair_act, pair_mask, index)) + pair_act = P.Add()(pair_act, self.triangle_multiplication_incoming(pair_act, pair_mask, index)) + pair_act = P.Add()(pair_act, self.triangle_attention_starting_node(pair_act, pair_mask, index)) + pair_act = P.Add()(pair_act, self.triangle_attention_ending_node(pair_act, pair_mask, index)) + pair_act = P.Add()(pair_act, self.pair_transition(pair_act, index)) + return pair_act + + +class MultimerEvoformer(nn.Cell): + '''multimerevoformer''' + # @lazy_inline + def __init__(self, config, msa_act_dim, pair_act_dim, is_extra_msa, batch_size, device_num): + super(MultimerEvoformer, self).__init__() + self.is_extra_msa = is_extra_msa + if is_extra_msa: + self.slice_cfg = config.slice.extra_msa_stack + else: + self.slice_cfg = config.slice.msa_stack + self.config = config.evoformer + + self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias( + self.config.msa_row_attention_with_pair_bias.num_head, + msa_act_dim, + self.config.msa_row_attention_with_pair_bias.gating, + msa_act_dim, + pair_act_dim, + device_num, + batch_size, + self.slice_cfg.msa_row_attention_with_pair_bias, + is_extra_msa) + + # if not is_extra_msa: + # self.add_interface = AddInterface(msa_act_dim, batch_size) + # self.preprocess_sbr = PreprocessSBR(input_dim=128, output_dim=pair_act_dim, + # bais_and_relu=True, batch_size=batch_size) + + self.msa_transition = Transition(self.config.msa_transition.num_intermediate_factor, + msa_act_dim, + device_num, + batch_size, + self.slice_cfg.msa_transition) + + self.outer_product_mean = OuterProductMean(self.config.outer_product_mean.num_outer_channel, + msa_act_dim, + pair_act_dim, + device_num, + batch_size, + self.slice_cfg.outer_product_mean) + + self.triangle_attention_starting_node = TriangleAttention( + self.config.triangle_attention_starting_node.orientation, + self.config.triangle_attention_starting_node.num_head, + pair_act_dim, + self.config.triangle_attention_starting_node.gating, + pair_act_dim, + device_num, + batch_size, + self.slice_cfg.triangle_attention_starting_node) + + self.triangle_attention_ending_node = TriangleAttention(self.config.triangle_attention_ending_node.orientation, + 
self.config.triangle_attention_ending_node.num_head, + pair_act_dim, + self.config.triangle_attention_ending_node.gating, + pair_act_dim, + device_num, + batch_size, + self.slice_cfg.triangle_attention_ending_node) + + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, + pair_act_dim, + device_num, + batch_size, + self.slice_cfg.pair_transition) + + self.triangle_multiplication_outgoing = TriangleMultiplication( + self.config.triangle_multiplication_outgoing.num_intermediate_channel, + self.config.triangle_multiplication_outgoing.equation, + pair_act_dim, + device_num, + batch_size=batch_size, + ) + + self.triangle_multiplication_incoming = TriangleMultiplication( + self.config.triangle_multiplication_incoming.num_intermediate_channel, + self.config.triangle_multiplication_incoming.equation, + pair_act_dim, + device_num, + batch_size=batch_size) + if is_extra_msa: + self.attn_mod = MSAColumnGlobalAttention(self.config.msa_column_attention.num_head, + self.config.msa_column_attention.gating, + msa_act_dim, + device_num, + batch_size, + self.slice_cfg.msa_column_global_attention) + else: + self.attn_mod = MSAColumnAttention(self.config.msa_column_attention.num_head, + msa_act_dim, + self.config.msa_column_attention.gating, + msa_act_dim, + device_num, + batch_size, + self.slice_cfg.msa_column_attention) + + if config.is_training: + # if not is_extra_msa: + # # self.add_interface.recompute() + # self.preprocess_sbr.recompute() + self.msa_row_attention_with_pair_bias.recompute() + self.attn_mod.recompute() + self.msa_transition.recompute() + self.triangle_multiplication_outgoing.recompute() + self.triangle_multiplication_incoming.recompute() + self.triangle_attention_starting_node.recompute() + self.triangle_attention_ending_node.recompute() + self.outer_product_mean.recompute() + self.pair_transition.recompute() + self.add2 = P.Add().shard(((1, device_num, 1), (1, device_num, 1))) + + def construct(self, msa_act, pair_act, msa_mask, extra_msa_norm, pair_mask, sbr_act=None, sbr_mask=None, interface_mask=None, index=None): + '''construct''' + pair_act = self.add2(pair_act, self.outer_product_mean(msa_act, msa_mask, extra_msa_norm, index)) + # raw notion + # if not self.is_extra_msa: + # msa_act = P.Add()(msa_act, self.add_interface(msa_act, interface_mask, index)) + + msa_act = self.add2(msa_act, self.msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act, sbr_act, sbr_mask, interface_mask, index)) + msa_act = self.add2(msa_act, self.attn_mod(msa_act, msa_mask, index)) + msa_act = self.add2(msa_act, self.msa_transition(msa_act, index)) + # print("This msa_act:", msa_act) + # raw notion + # if not self.is_extra_msa: + # pair_act = P.Add()(pair_act, self.preprocess_sbr(sbr_act, sbr_mask, index)) + + pair_act = self.add2(pair_act, self.triangle_multiplication_outgoing(pair_act, pair_mask, index)) + pair_act = self.add2(pair_act, self.triangle_multiplication_incoming(pair_act, pair_mask, index)) + pair_act = self.add2(pair_act, self.triangle_attention_starting_node(pair_act, pair_mask, index)) + pair_act = self.add2(pair_act, self.triangle_attention_ending_node(pair_act, pair_mask, index)) + pair_act = self.add2(pair_act, self.pair_transition(pair_act, index)) + return msa_act, pair_act \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/module/evogen_block.py b/MindSPONGE/applications/research/Grasp/module/evogen_block.py new file mode 100644 index 
0000000000000000000000000000000000000000..9f6b5d3587be01b3f60183e77e0b6b6949fc49cc --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/module/evogen_block.py @@ -0,0 +1,660 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""evogen block""" +import numpy as np +from mindspore import nn +from mindspore import Tensor +from mindspore import Parameter +import mindspore.ops as ops +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.initializer import initializer +import mindspore.common.dtype as mstype +import mindspore.nn.probability.distribution as msd +import mindspore.numpy as msnp + +from mindsponge1.cell import Attention, MSARowAttentionWithPairBias, Transition, \ + OuterProductMean, TriangleMultiplication, TriangleAttention +from mindsponge1.cell.initializer import lecun_init +from mindsponge1.cell.mask import MaskedLayerNorm + + +def absolute_position_embedding(length, depth, min_timescale=1, max_timescale=1e4): + '''absolute_position_embedding''' + depth = depth // 2 + positions = np.arange(length, dtype=np.float32) + log_timescale_increment = (np.log(max_timescale / min_timescale) / (depth - 1)) + inv_timescales = min_timescale * np.exp(np.arange(depth, dtype=np.float32) * -log_timescale_increment) + scaled_time = np.expand_dims(positions, 1) * np.expand_dims(inv_timescales, 0) + x = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) + return x + + +class EvoGenFeatProcess(nn.Cell): + '''EvoGenFeatProcess''' + + def __init__( + self, + config, + ): + super().__init__() + self.num_msa = config.global_config.num_msa + self.max_num_res = config.global_config.max_num_res + self.num_aa_types = config.global_config.num_aa_types + self.num_msa_types = self.num_aa_types + 1 + self.seq_weight_power = Tensor(config.train.seq_weight_power, mstype.float32) + self.rel_pos_generator = RelativePositionEmbedding(config.model.embeddings_and_evoformer) + self.label_smooting = config.train.label_smoothing + self.pi = np.pi + self.msa_onehot = nn.OneHot(depth=self.num_msa_types) + self.q_onehot = nn.OneHot(depth=self.num_aa_types) + + def construct(self, query_input, msa_input, additional_input, random_input, random_mask): + '''Transform input data into a set of labels and feats.''' + msa_ = msa_input[:, :, 0] + x_new_number = 7. * P.OnesLike()(msa_) + x_where = P.Equal()(msa_, 20) + msa_ = P.Select()(x_where, x_new_number, msa_) + msa_input = P.Concat(-1)((P.ExpandDims()(msa_, -1), msa_input[:, :, 1:])) + aa_labels = msa_input[:, :, 0] + + del_num = msa_input[:, :, 1] + del_num_feat = 2. / self.pi * msnp.arctan(del_num / 3.) + has_del_label = ops.clip_by_value(del_num, 0, 1) + + msa_mask = P.ExpandDims()(additional_input[:, 2], 0) + pair_mask = P.MatMul(transpose_a=True)(msa_mask, msa_mask) + context_mask = random_mask[:, 0] + context_mask[0] = 1. 
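+        # Row 0 is the query sequence: it is forced into the context set here and
+        # into the target set just below, so the model always conditions on the
+        # query regardless of the random mask draw.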
+ context_mask = P.ExpandDims()(context_mask, 1) + target_mask = random_mask[:, 1] + target_mask[0] = 1. + target_mask = P.ExpandDims()(target_mask, 1) + norm_const_predict = P.ReduceSum()(msa_mask) + + if self.seq_weight_power > 1e-5: + norm_const_predict = norm_const_predict / 100. + norm_const_predict = P.Pow()(norm_const_predict, self.seq_weight_power) + else: + norm_const_predict = Tensor(1.0, mstype.float32) + + msa_raw_feat = msa_input[:, :, 0] + msa_raw_feat = P.Concat(0)((P.ExpandDims()(query_input[:, 0], 0), msa_raw_feat[1:, :])) + msa_raw_feat = self.msa_onehot(msa_raw_feat.astype(mstype.int32)) + msa_raw_feat = P.Concat(-1)((msa_raw_feat.astype(has_del_label.dtype), P.ExpandDims()(has_del_label, -1), + P.ExpandDims()(del_num_feat, -1))) + + q_raw_feat = self.q_onehot(query_input[:, 0].astype(mstype.int32)).astype(query_input.dtype) + + res_idx = additional_input[:, 1] + pair_raw_feat = self.rel_pos_generator(res_idx, res_idx) + + msa_labels_onehot = self.msa_onehot(msa_input[:, :, 0].astype(mstype.int32)) + + msa_mask_full = msa_input[:, :, 2] + + msa_mask_full = P.ExpandDims()(msa_mask_full, -1) + + msa_profile = P.ReduceSum()(msa_labels_onehot * msa_mask_full, 0) + num_seq = P.ReduceSum()(msa_mask_full[:, 0, 0]) + + msa_profile = msa_profile / (num_seq + 1e-5) + + q_labels = msa_labels_onehot[:1] + aa_labels_onehot = self.msa_onehot(aa_labels.astype(mstype.int32)) + aa_labels = P.Concat(0)((q_labels, aa_labels_onehot[1:])) + aa_labels = (1. - self.label_smooting) * aa_labels + self.label_smooting * (P.ExpandDims()(msa_profile, 0)) + + label_tuple = (aa_labels, del_num, has_del_label, del_num_feat, msa_profile, norm_const_predict) + feat_tuple = (q_raw_feat, msa_raw_feat, pair_raw_feat, msa_mask, pair_mask, context_mask, + target_mask, res_idx, random_input) + + return label_tuple, feat_tuple + + +class LatentNormal(nn.Cell): + '''LatentNormal''' + + def __init__(self): + super().__init__() + self.exp = F.exp + self.log = P.Log() + self.pi2 = Tensor(2 * np.pi, mstype.float32) + self.standard_normal = msd.Normal(mean=Tensor(0, dtype=mstype.float32), + sd=Tensor(1, dtype=mstype.float32), dtype=mstype.float32) + self.tanh = ops.Tanh() + + def sample(self, mu, log_sigma, temp=1.): + '''sample''' + mu, sigma = self._process_data(mu, log_sigma, temp) + eps = self.standard_normal.sample(mu.shape) + return eps * sigma + mu + + def sample_given_eps(self, eps, mu, log_sigma, temp=1.): + '''sample_given_eps''' + mu, sigma = self._process_data(mu, log_sigma, temp) + return eps * sigma + mu + + def construct(self, mu, log_sigma, normal_dist_mu, normal_dist_log_sigma): + '''kl''' + mu, sigma = self._process_data(mu, log_sigma) + normal_dist_mu, normal_dist_sigma = self._process_data(normal_dist_mu, normal_dist_log_sigma) + term1 = (mu - normal_dist_mu) / normal_dist_sigma + term2 = sigma / normal_dist_sigma + return 0.5 * (term1 * term1 + term2 * term2) - 0.5 - F.log(term2) + + def _process_data(self, mu, log_sigma, temp=1.): + '''process_data''' + mu = 5. * self.tanh(mu / 5.) + log_sigma = 5. * self.tanh(log_sigma / 5.) 
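+        # Soft-clamp both statistics to (-5, 5) with a scaled tanh so that the
+        # exponentiation below cannot overflow and sigma stays strictly positive.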
+ sigma = self.exp(log_sigma) + sigma *= temp + return mu, sigma + + +class RelativePositionEmbedding(nn.Cell): + '''RelativePositionEmbedding''' + + def __init__(self, + config, + ): + super(RelativePositionEmbedding, self).__init__() + + self.exact_distance = config.exact_distance + self.num_buckets = config.num_buckets + self.max_distance = config.max_distance + self.onehot = nn.OneHot(depth=2 * self.num_buckets + 1) + + @staticmethod + def _relative_position_bucket(x, alpha=16.0, beta=32.0, gamma=64.0): + '''_relative_position_bucket''' + alpha = Tensor(alpha, mstype.float32) + beta = Tensor(beta, mstype.float32) + gamma = Tensor(gamma, mstype.float32) + + scale = (beta - alpha) / F.log(gamma / alpha) + x_abs = P.Abs()(x) + gx = F.log((x_abs + 1e-3) / alpha) * scale + alpha + gx = P.Minimum()(beta, gx) + gx = P.Sign()(x) * gx + + cond = P.Greater()(x_abs, alpha) + ret = P.Select()(cond, gx, x) + ret = ops.clip_by_value(ret, -beta, beta) + + ret += beta + return ret + + def construct(self, q_idx, k_idx): + """ Compute binned relative position encoding """ + + context_position = P.ExpandDims()(q_idx, 1) + memory_position = P.ExpandDims()(k_idx, 0) + relative_position = memory_position - context_position + rp_bucket = self._relative_position_bucket(relative_position) + rp_onehot = self.onehot(rp_bucket.astype(mstype.int32)) + return rp_onehot + + +class EvogenAttention(Attention): + '''EvogenAttention''' + + def __init__(self, config, q_data_dim, m_data_dim, output_dim): + super(EvogenAttention, self).__init__(config.num_head, q_data_dim, + config.gating, q_data_dim, m_data_dim, output_dim, batch_size=None) + self.ape_table = config.ape_table + if self.ape_table is not None: + self.ape_table = Tensor(self.ape_table, mstype.float32) + self.onehot = nn.OneHot(depth=1024) + + def rope(self, hidden_states, res_idx): + '''rope''' + c_m = hidden_states.shape[-1] + n_res = res_idx.shape[0] + + idx_one_hot = self.onehot(res_idx.astype(mstype.int32)) + ape_sin, ape_cos = ops.Split(axis=-1, output_num=2)(self.ape_table) + ape_table = P.Concat(-1)([ape_cos, ape_sin]) + + rope = P.MatMul()(idx_one_hot, ape_table) + rope_double = P.Reshape()(P.Tile()(P.ExpandDims()(rope, -1), (1, 1, 2)), (n_res, -1)) + rope_cos, rope_sin = P.Split(axis=-1, output_num=2)(rope_double) + + vec_ = P.Reshape()(hidden_states, (-1, c_m // 2, 2)) + vec_even, vec_odd = P.Split(axis=-1, output_num=2)(vec_) + vec2 = P.Concat(axis=-1)([-vec_odd, vec_even]) + vec2 = P.Reshape()(vec2, hidden_states.shape) + + vec1 = P.Reshape()(hidden_states, (-1, n_res, c_m)) + vec2 = P.Reshape()(vec2, (-1, n_res, c_m)) + vec_rope = vec1 * P.ExpandDims()(rope_cos, 0) + \ + vec2 * P.ExpandDims()(rope_sin, 0) + + return P.Reshape()(vec_rope, hidden_states.shape) + + def construct(self, q_data, m_data, bias, pair_bias=None, res_idx=None): + '''construct''' + linear_gating_weight = 0 + if self.gating: + linear_gating_weight = self.linear_gating_weights + + b_dim, q_dim, a_dim = q_data.shape + _, k_dim, c_dim = m_data.shape + q_data = P.Reshape()(q_data, (-1, a_dim)) + m_data = P.Reshape()(m_data, (-1, c_dim)) + + q = self.matmul(q_data, self.linear_q_weights) * self.dim_per_head ** (-0.5) + k = self.matmul(m_data, self.linear_k_weights) + v = self.matmul(m_data, self.linear_v_weights) + + if (res_idx is not None) and (self.ape_table is not None): + q = self.rope(q, res_idx) + k = self.rope(k, res_idx) + + q = P.Reshape()(q, (b_dim, q_dim, self.num_head, -1)) + k = P.Reshape()(k, (b_dim, k_dim, self.num_head, -1)) + v = P.Reshape()(v, (b_dim, k_dim, 
self.num_head, -1)) + + tmp_q = P.Reshape()(P.Transpose()(q, (0, 2, 1, 3)), (b_dim * self.num_head, q_dim, -1)) + tmp_k = P.Reshape()(P.Transpose()(k, (0, 2, 1, 3)), (b_dim * self.num_head, k_dim, -1)) + logits = P.Add()(P.Reshape()(self.batch_matmul_trans_b(tmp_q, tmp_k), (b_dim, self.num_head, q_dim, k_dim)), + bias) + + if pair_bias is not None: + bias_ = P.ExpandDims()(pair_bias, 0) + logits = P.Add()(logits, bias_) + + probs = self.softmax(logits) + tmp_v = P.Reshape()(P.Transpose()(v, (0, 2, 3, 1)), (b_dim * self.num_head, -1, k_dim)) + tmp_probs = P.Reshape()(probs, (b_dim * self.num_head, q_dim, k_dim)) + + weighted_avg = P.Transpose()( + P.Reshape()(self.batch_matmul_trans_b(tmp_probs, tmp_v), (b_dim, self.num_head, q_dim, -1)), + (0, 2, 1, 3)) + + if self.gating: + gating_bias = P.ExpandDims()(P.ExpandDims()(self.gating_biases, 0), 0) + gate_values = P.Add()( + P.Reshape()(self.matmul(q_data, linear_gating_weight), (b_dim, q_dim, self.num_head, -1)), + gating_bias) + gate_values = gate_values + gate_values = self.sigmoid(gate_values) + gate_values = gate_values + weighted_avg = weighted_avg * gate_values + + weighted_avg = P.Reshape()(weighted_avg, (b_dim * q_dim, -1)) + output = P.Add()(P.Reshape()(self.matmul(weighted_avg, self.linear_output_weights), (b_dim, q_dim, -1)), + P.ExpandDims()(self.o_biases, 0)) + return output + + +class EvogenMSARowAttentionWithPairBias(MSARowAttentionWithPairBias): + '''EvogenMSARowAttentionWithPairBias''' + + def __init__(self, config, msa_act_dim, pair_act_dim): + super(EvogenMSARowAttentionWithPairBias, self).__init__(config.num_head, msa_act_dim, config.gating, + msa_act_dim, pair_act_dim) + self.config = config + self.attn_mod = EvogenAttention(self.config, msa_act_dim, msa_act_dim, msa_act_dim) + + def _compute(self, msa_act, bias, pair_bias=None, res_idx=None): + '''compute''' + msa_act = self.attn_mod(msa_act, msa_act, bias, pair_bias=pair_bias, res_idx=res_idx) + return msa_act + + +class MSAConditioner(nn.Cell): + '''MSAConditioner''' + + def __init__(self, config, layer_norm_dim): + super(MSAConditioner, self).__init__() + self.config = config + self.layer_norm_dim = layer_norm_dim + self.num_intermediate = int(layer_norm_dim * self.config.num_intermediate_factor) + self.act_fn = nn.ReLU() + self.matmul = P.MatMul(transpose_b=True) + self.sigmoid = nn.Sigmoid() + self.masked_layer_norm = MaskedLayerNorm() + self._init_parameter() + + def construct(self, act, mask): + '''construct''' + act_ = self.masked_layer_norm(act, self.input_layer_norm_gammas, self.input_layer_norm_betas, mask=mask) + q_act = P.ExpandDims()(act_[0], 0) + mix_act = P.Concat(-1)((P.Tile()(q_act, (act_.shape[0], 1, 1)), act_)) + act_shape = P.Shape()(mix_act) + if len(act_shape) != 2: + mix_act = P.Reshape()(mix_act, (-1, act_shape[-1])) + mix_act = self.act_fn(P.BiasAdd()(self.matmul(mix_act, self.transition_weights), self.transition_biases)) + gate_values = P.BiasAdd()(self.matmul(mix_act, self.linear_gating_weights), self.gating_biases) + gate_values = self.sigmoid(gate_values) + gate_values = P.Reshape()(gate_values, act.shape) + return act, gate_values + + def _init_parameter(self): + '''init_parameter''' + self.input_layer_norm_gammas = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32)) + self.input_layer_norm_betas = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32)) + self.transition_weights = Parameter(initializer(lecun_init(2 * self.layer_norm_dim, initializer_name='relu'), + [self.num_intermediate, 2 * self.layer_norm_dim])) 
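+        # The gating projection defined below uses zero weights and all-ones biases,
+        # so at initialisation sigmoid(gate) is about 0.73: each MSA row keeps most
+        # of its own activation and mixes in only a little of the query row (see the
+        # interpolation in EvoformerSeqBlock.construct).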
+ self.transition_biases = Parameter(Tensor(np.zeros((self.num_intermediate)), mstype.float32)) + self.linear_gating_weights = Parameter( + Tensor(np.zeros([self.layer_norm_dim, self.num_intermediate]), mstype.float32)) + self.gating_biases = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32)) + + +class EvoformerSeqBlock(nn.Cell): + '''EvoformerSeqBlock''' + + def __init__(self, config, msa_act_dim, pair_act_dim, encoding=True): + super(EvoformerSeqBlock, self).__init__() + self.config = config + self.msa_row_attention_with_pair_bias = EvogenMSARowAttentionWithPairBias( + self.config.msa_row_attention_with_pair_bias, msa_act_dim, pair_act_dim) + self.msa_transition = Transition(self.config.msa_transition.num_intermediate_factor, msa_act_dim) + self.encoding = encoding + if self.encoding: + self.msa_conditioner = MSAConditioner(self.config.msa_condition, msa_act_dim) + + def construct(self, msa_act, pair_act, msa_mask, pair_mask, res_idx=None): + '''construct''' + msa_act = P.Add()(msa_act, + self.msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act, 0, + msa_mask, pair_mask, res_idx=res_idx)) + msa_act = P.Add()(msa_act, self.msa_transition(msa_act, 0, msa_mask)) + + if self.encoding: + act, gate_values = self.msa_conditioner(msa_act, msa_mask) + else: + act, gate_values = msa_act, 1. + msa_act = P.Add()(gate_values * act, (1. - gate_values) * P.ExpandDims()(act[0], 0)) + return msa_act + + +class EvoformerPairBlock(nn.Cell): + '''EvoformerPairBlock''' + + def __init__(self, config, msa_act_dim, pair_act_dim): + super(EvoformerPairBlock, self).__init__() + self.config = config + self.outer_product = OuterProductMean(self.config.outer_product.num_outer_channel, msa_act_dim, pair_act_dim) + self.triangle_multiplication_outgoing = TriangleMultiplication( + self.config.triangle_multiplication_outgoing.num_intermediate_channel, + self.config.triangle_multiplication_outgoing.equation, + pair_act_dim) + self.triangle_multiplication_incoming = TriangleMultiplication( + self.config.triangle_multiplication_incoming.num_intermediate_channel, + self.config.triangle_multiplication_incoming.equation, + pair_act_dim) + self.triangle_attention_starting_node = TriangleAttention( + self.config.triangle_attention_starting_node.orientation, + self.config.triangle_attention_starting_node.num_head, + pair_act_dim, + self.config.triangle_attention_starting_node.gating, + pair_act_dim) + self.triangle_attention_ending_node = TriangleAttention(self.config.triangle_attention_ending_node.orientation, + self.config.triangle_attention_ending_node.num_head, + pair_act_dim, + self.config.triangle_attention_ending_node.gating, + pair_act_dim) + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, pair_act_dim) + + def construct(self, msa_act, pair_act, msa_mask, pair_mask, context_mask, mask_norm=None): + '''construct''' + msa_mask_ = msa_mask * context_mask + pair_act = P.Add()(pair_act, self.outer_product(msa_act, msa_mask_, mask_norm)) + pair_act = P.Add()(pair_act, self.triangle_multiplication_outgoing(pair_act, pair_mask)) + pair_act = P.Add()(pair_act, self.triangle_multiplication_incoming(pair_act, pair_mask)) + + pair_act = P.Add()(pair_act, self.triangle_attention_starting_node(pair_act, pair_mask, mask=pair_mask)) + pair_mask_ = P.Transpose()(pair_mask, (1, 0)) + pair_act = P.Add()(pair_act, self.triangle_attention_ending_node(pair_act, pair_mask, mask=pair_mask_)) + pair_act = P.Add()(pair_act, self.pair_transition(pair_act, 0, pair_mask)) + return 
pair_act
+
+
+class EvoformerIteration(nn.Cell):
+    '''EvoformerIteration'''
+
+    def __init__(self, config, msa_act_dim, pair_act_dim, encoding=True):
+        super(EvoformerIteration, self).__init__()
+        self.config = config.model.embeddings_and_evoformer.evoformer
+        self.evoformer_seq_block = EvoformerSeqBlock(self.config, msa_act_dim, pair_act_dim, encoding=encoding)
+        if config.global_config.recompute:
+            self.evoformer_seq_block.recompute()
+        self.encoding = encoding
+        if self.encoding:
+            self.evoformer_pair_block = EvoformerPairBlock(self.config, msa_act_dim, pair_act_dim)
+            if config.global_config.recompute:
+                self.evoformer_pair_block.recompute()
+
+    def construct(self, msa_act, pair_act, msa_mask, pair_mask, context_mask, mask_norm=None, res_idx=None):
+        '''construct'''
+        msa_act_ = msa_act
+        msa_act = self.evoformer_seq_block(msa_act_, pair_act, msa_mask, pair_mask, res_idx=res_idx)
+        if self.encoding:
+            pair_act = self.evoformer_pair_block(msa_act_, pair_act, msa_mask, pair_mask, context_mask,
+                                                 mask_norm=mask_norm)
+        return msa_act, pair_act
+
+
+class LatentTransition(nn.Cell):
+    '''LatentTransition'''
+
+    def __init__(self, config, input_dim, output_dim):
+        super(LatentTransition, self).__init__()
+        self.config = config
+        self.layer_norm_dim = input_dim
+        self.num_intermediate = int(input_dim * self.config.num_intermediate_factor)
+        self.output_dim = output_dim
+        self.act_fn = nn.ReLU()
+        self.matmul = P.MatMul(transpose_b=True)
+        self.masked_layer_norm = MaskedLayerNorm()
+        self._init_parameter()
+
+    def construct(self, act, mask):
+        '''construct'''
+        act = self.masked_layer_norm(act, self.input_layer_norm_gammas, self.input_layer_norm_betas, mask=mask)
+        act_shape = P.Shape()(act)
+        if len(act_shape) != 2:
+            act = P.Reshape()(act, (-1, act_shape[-1]))
+        act1 = P.BiasAdd()(self.matmul(act, self.linear0_weights), self.linear0_biases)
+
+        act = self.act_fn(P.BiasAdd()(self.matmul(act, self.linear1_weights), self.linear1_biases))
+        act = self.act_fn(P.BiasAdd()(self.matmul(act, self.linear2_weights), self.linear2_biases))
+        act = P.BiasAdd()(self.matmul(act, self.linear3_weights), self.linear3_biases)
+
+        act = P.Add()(act, act1)
+        act = P.Reshape()(act, act_shape[:-1] + (-1,))
+        return act
+
+    def _init_parameter(self):
+        '''init parameter'''
+        self.input_layer_norm_gammas = Parameter(Tensor(np.ones((self.layer_norm_dim)), mstype.float32))
+        self.input_layer_norm_betas = Parameter(Tensor(np.zeros((self.layer_norm_dim)), mstype.float32))
+
+        self.linear0_weights = Parameter(
+            initializer(lecun_init(self.layer_norm_dim), [self.output_dim, self.layer_norm_dim]))
+        self.linear0_biases = Parameter(Tensor(np.zeros((self.output_dim)), mstype.float32))
+
+        self.linear1_weights = Parameter(initializer(lecun_init(self.layer_norm_dim, initializer_name='relu'),
+                                                     [self.num_intermediate, self.layer_norm_dim]))
+        self.linear1_biases = Parameter(Tensor(np.zeros((self.num_intermediate)), mstype.float32))
+        self.linear2_weights = Parameter(initializer(lecun_init(self.layer_norm_dim, initializer_name='relu'),
+                                                     [self.num_intermediate, self.layer_norm_dim]))
+        self.linear2_biases = Parameter(Tensor(np.zeros((self.num_intermediate)), mstype.float32))
+
+        self.linear3_weights = Parameter(Tensor(np.zeros((self.output_dim, self.layer_norm_dim)), mstype.float32))
+        self.linear3_biases = Parameter(Tensor(np.zeros((self.output_dim)), mstype.float32))
+
+
+class ColumnAttentionWithPairBias(nn.Cell):
+    '''ColumnAttentionWithPairBias'''
+
+    def __init__(self, config, input_dim, output_dim):
+
super(ColumnAttentionWithPairBias, self).__init__() + self.attn_mod = EvogenAttention(config, input_dim, input_dim, output_dim) + self.input_norm_gammas = Parameter(Tensor(np.ones([input_dim]), mstype.float32)) + self.input_norm_betas = Parameter(Tensor(np.zeros([input_dim]), mstype.float32)) + self.masked_layer_norm = MaskedLayerNorm() + + def construct(self, q, k, q_mask, k_mask): + '''construct''' + q_act = P.Transpose()(q, (1, 0, 2)) + k_act = P.Transpose()(k, (1, 0, 2)) + q_act = self.masked_layer_norm(q_act, self.input_norm_gammas, self.input_norm_betas, mask=q_mask) + k_act = self.masked_layer_norm(k_act, self.input_norm_gammas, self.input_norm_betas, mask=k_mask) + + bias = 1e9 * (k_mask - 1.0) + bias = P.ExpandDims()(P.ExpandDims()(bias, 1), 2) + act = self.attn_mod(q_act, k_act, bias) + act = P.Transpose()(act, (1, 0, 2)) + return act + + +class LatentTransformerBlock(nn.Cell): + '''LatentTransformerBlock''' + + def __init__(self, config, input_dim, output_dim): + super(LatentTransformerBlock, self).__init__() + self.column_attention_with_pair_bias = ColumnAttentionWithPairBias( + config.column_attention_with_pair_bias, input_dim, output_dim) + self.transition = Transition(config.msa_transition.num_intermediate_factor, output_dim) + + def construct(self, q_act, k_act, q_mask, k_mask): + '''construct''' + act = P.Add()(q_act, self.column_attention_with_pair_bias(q_act, k_act, q_mask, k_mask)) + q_mask_t = P.Transpose()(q_mask, (1, 0)) + act = P.Add()(act, self.transition(act, 0, q_mask_t)) + return act + + +class LatentStatistics(nn.Cell): + '''LatentStatistics''' + + def __init__(self, config, latent_dim): + super(LatentStatistics, self).__init__() + self.num_intermediate = int(latent_dim * config.num_intermediate_factor) + self.act_fn = nn.ReLU() + self.matmul = P.MatMul(transpose_b=True) + self.split = ops.Split(axis=-1, output_num=2) + self.prior_net1_weights = Parameter( + initializer(lecun_init(latent_dim, initializer_name='relu'), [self.num_intermediate, latent_dim])) + self.prior_net1_biases = Parameter(Tensor(np.zeros((self.num_intermediate)), mstype.float32)) + self.prior_net2_weights = Parameter( + Tensor(np.zeros((2 * latent_dim, self.num_intermediate)), mstype.float32)) + self.prior_net2_biases = Parameter(Tensor(np.zeros((2 * latent_dim)), mstype.float32)) + + def construct(self, w_act, v_act): + '''construct''' + act_shape = P.Shape()(w_act) + if len(act_shape) != 2: + w_act = P.Reshape()(w_act, (-1, act_shape[-1])) + prior_state = self.act_fn(P.BiasAdd()(self.matmul(w_act, self.prior_net1_weights), self.prior_net1_biases)) + prior_state = P.BiasAdd()(self.matmul(prior_state, self.prior_net2_weights), self.prior_net2_biases) + prior_state = P.Reshape()(prior_state, act_shape[:-1] + (-1,)) + mu_prior, log_sigma_prior = self.split(prior_state) + + act_shape = P.Shape()(v_act) + if len(act_shape) != 2: + v_act = P.Reshape()(v_act, (-1, act_shape[-1])) + posterior_state = self.act_fn(P.BiasAdd()(self.matmul(v_act, self.prior_net1_weights), self.prior_net1_biases)) + posterior_state = P.BiasAdd()(self.matmul(posterior_state, self.prior_net2_weights), self.prior_net2_biases) + posterior_state = P.Reshape()(posterior_state, act_shape[:-1] + (-1,)) + mu_posterior, log_sigma_posterior = self.split(posterior_state) + latent_statistics_result = mu_prior, log_sigma_prior, mu_posterior, log_sigma_posterior + return latent_statistics_result + + +class LatentRemap(nn.Cell): + '''LatentRemap''' + + def __init__(self, config, input_dim, output_dim): + super(LatentRemap, 
self).__init__() + self.transition = Transition(config.msa_transition.num_intermediate_factor, output_dim) + self.matmul = P.MatMul(transpose_b=True) + self.linear_weights = Parameter(initializer(lecun_init(input_dim), [output_dim, input_dim])) + self.linear_biases = Parameter(Tensor(np.zeros((output_dim)), mstype.float32)) + + def construct(self, act, h_act, mask): + '''construct''' + act_shape = P.Shape()(act) + if len(act_shape) != 2: + act = P.Reshape()(act, (-1, act_shape[-1])) + act = P.BiasAdd()(self.matmul(act, self.linear_weights), self.linear_biases) + h_act_star = P.Reshape()(act, act_shape[:-1] + (-1,)) + delta_h = h_act_star - h_act + delta_h = P.Add()(delta_h, self.transition(delta_h, 0, mask)) + return delta_h + + +class LatentBlock(nn.Cell): + '''LatentBlock''' + + def __init__(self, config, msa_dim, latent_dim): + super(LatentBlock, self).__init__() + self.config = config.model.latent + self.temperature = self.config.temperature + self.encoder_latent_projection = LatentTransition(self.config.latent_transition, msa_dim, latent_dim) + self.decoder_latent_projection = LatentTransition(self.config.latent_transition, msa_dim, latent_dim) + self.context_transformer_layers = self.config.context_layers + blocks = nn.CellList() + for _ in range(self.context_transformer_layers): + block = LatentTransformerBlock(self.config, latent_dim, latent_dim) + if config.global_config.recompute: + block.recompute() + blocks.append(block) + self.context_transformer = blocks + self.match_transformer = LatentTransformerBlock(self.config, latent_dim, latent_dim) + if config.global_config.recompute: + self.match_transformer.recompute() + + self.noise_transformer = LatentTransformerBlock(self.config, latent_dim, latent_dim) + if config.global_config.recompute: + self.noise_transformer.recompute() + + self.latent_statistics = LatentStatistics(self.config.latent_statistics, latent_dim) + self.latent_normal = LatentNormal() + self.latent_mapper = LatentRemap(self.config, latent_dim, msa_dim) + + def construct(self, dec_act, enc_act, msa_mask, context_mask, target_mask, eps=None): + '''construct''' + q_mask_u = P.Reshape()(context_mask, (1, -1)) + q_mask_w = P.Reshape()(target_mask, (1, -1)) + + u_act = self.encoder_latent_projection(enc_act, msa_mask) + w_act = self.decoder_latent_projection(dec_act, msa_mask) + u_act_star = u_act + for i in range(self.context_transformer_layers): + u_act_star = self.context_transformer[i](u_act, u_act, q_mask_u, q_mask_u) + + w_act_star = self.match_transformer(w_act, u_act_star, q_mask_w, q_mask_u) + v_act_star = self.match_transformer(u_act, u_act_star, q_mask_w, q_mask_u) + mu_prior, log_sigma_prior, mu_posterior, log_sigma_posterior = self.latent_statistics(w_act_star, v_act_star) + target_mask = P.Reshape()(target_mask, (-1, 1, 1)) + + mu_posterior = target_mask * mu_posterior + (1. - target_mask) * mu_prior + log_sigma_posterior = target_mask * log_sigma_posterior + (1. - target_mask) * log_sigma_prior + if eps is not None: + eps[0] *= 0. 
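+            # Reparameterised sampling: with externally supplied noise the draw
+            # is deterministic given (mu, log_sigma); zeroing the first noise
+            # slice pins the leading sample to the posterior mean. A sketch of
+            # the expected behaviour of sample_given_eps:
+            #     z = mu + temperature * exp(log_sigma) * eps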
+            z_act = self.latent_normal.sample_given_eps(eps, mu_posterior, log_sigma_posterior, temp=self.temperature)
+        else:
+            z_act = self.latent_normal.sample(mu_posterior, log_sigma_posterior, temp=self.temperature)
+
+        z_act_star = self.noise_transformer(z_act, u_act_star, q_mask_w, q_mask_u)
+        delta_h = self.latent_mapper(z_act_star, dec_act, msa_mask)
+        dec_act = P.Add()(dec_act, delta_h)
+        latent_block_result = dec_act, mu_prior, log_sigma_prior, mu_posterior, log_sigma_posterior
+        return latent_block_result
diff --git a/MindSPONGE/applications/research/Grasp/module/fold_wrapcell.py b/MindSPONGE/applications/research/Grasp/module/fold_wrapcell.py
new file mode 100644
index 0000000000000000000000000000000000000000..007f995caab5b6cec5b84e913873af2ed4c4868c
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/module/fold_wrapcell.py
@@ -0,0 +1,212 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""wrap cell"""
+
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import ops
+from mindspore.context import ParallelMode
+from mindspore.nn import DistributedGradReducer
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.parallel._utils import _get_device_num
+from mindspore.parallel._utils import (_get_gradients_mean, _get_parallel_mode)
+from module.loss_module import LossNet
+
+GRADIENT_CLIP_TYPE = 1
+
+clip_grad = ops.MultitypeFuncGraph("clip_grad")
+
+
+@clip_grad.register("Number", "Number", "Tensor")
+def _clip_grad(clip_type, clip_value, grad):
+    """_clip_grad"""
+    if clip_type not in (0, 1):
+        return grad
+    dt = ops.dtype(grad)
+    if clip_type == 0:
+        new_grad = ops.clip_by_value(grad, ops.cast(ops.tuple_to_array((-clip_value,)), dt),
+                                     ops.cast(ops.tuple_to_array((clip_value,)), dt))
+    else:
+        new_grad = nn.ClipByNorm()(grad, ops.cast(ops.tuple_to_array((clip_value,)), dt))
+    return new_grad
+
+
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+
+
+@grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    """tensor_grad_scale"""
+    return grad * ops.Reciprocal()(scale)
+
+
+class TrainOneStepCell(nn.Cell):
+    """TrainOneStepCell"""
+    def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True, use_global_norm=True,
+                 gradient_clip_value=1.0, train_fold=True):
+        super(TrainOneStepCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.set_grad()
+        self.optimizer = optimizer
+        self.weights = self.optimizer.parameters
+        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
+        self.sens = sens
+        self.enable_clip_grad = enable_clip_grad
+        self.hyper_map = ops.HyperMap()
+        self.use_global_norm = use_global_norm
+        self.gradient_clip_value = gradient_clip_value
+        self.train_fold = train_fold
+
+        self.reducer_flag = False
+        self.grad_reducer = F.identity
+        self.parallel_mode = _get_parallel_mode()
+        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
+        if self.reducer_flag:
+            self.mean = _get_gradients_mean()
+            self.degree = _get_device_num()
+            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
+
+    def construct(self, *inputs):
+        """construct"""
+        # `train_backward` is expected to be attached to this cell by the
+        # training script (e.g. via add_flags) before the graph is compiled.
+        if self.train_backward:
+            loss_all = self.network(*inputs)
+            loss, l_fape_side, l_fape_backbone, l_anglenorm, \
+                distogram_loss, masked_loss, predict_lddt_loss, \
+                structure_violation_loss, no_clamp, fape_nc_intra, fape_nc_inter, \
+                chain_centre_mass_loss, aligned_error_loss, \
+                sbr_inter_fape_loss, sbr_inter_drmsd_loss, sbr_inter_disto_loss, \
+                sbr_intra_fape_loss, sbr_intra_drmsd_loss, sbr_intra_disto_loss, interface_loss, \
+                recall_intra, recall_inter, recall_interface, perfect_recall_interface, \
+                recall_inter1, recall_intra1 = loss_all
+
+            # Only the scalar `loss` gets a non-zero sensitivity; the remaining
+            # outputs are monitored metrics and must not contribute gradients.
+            sens = F.fill(loss.dtype, loss.shape, self.sens)
+            sens1 = F.fill(l_fape_side.dtype, l_fape_side.shape, 0.0)
+            sens2 = F.fill(l_fape_backbone.dtype, l_fape_backbone.shape, 0.0)
+            sens3 = F.fill(l_anglenorm.dtype, l_anglenorm.shape, 0.0)
+            sens4 = F.fill(distogram_loss.dtype, distogram_loss.shape, 0.0)
+            sens5 = F.fill(masked_loss.dtype, masked_loss.shape, 0.0)
+            sens6 = F.fill(predict_lddt_loss.dtype, predict_lddt_loss.shape, 0.0)
+            sens7 = F.fill(structure_violation_loss.dtype, structure_violation_loss.shape, 0.0)
+            sens8 = F.fill(no_clamp.dtype, no_clamp.shape, 0.0)
+            sens9 = F.fill(fape_nc_intra.dtype, fape_nc_intra.shape, 0.0)
+            sens10 = F.fill(fape_nc_inter.dtype, fape_nc_inter.shape, 0.0)
+            sens11 = F.fill(chain_centre_mass_loss.dtype, chain_centre_mass_loss.shape, 0.0)
+            sens12 = F.fill(aligned_error_loss.dtype, aligned_error_loss.shape, 0.0)
+            sens13 = F.fill(sbr_inter_fape_loss.dtype, sbr_inter_fape_loss.shape, 0.0)
+            sens14 = F.fill(sbr_inter_drmsd_loss.dtype, sbr_inter_drmsd_loss.shape, 0.0)
+            sens15 = F.fill(sbr_inter_disto_loss.dtype, sbr_inter_disto_loss.shape, 0.0)
+            sens16 = F.fill(sbr_intra_fape_loss.dtype, sbr_intra_fape_loss.shape, 0.0)
+            sens17 = F.fill(sbr_intra_drmsd_loss.dtype, sbr_intra_drmsd_loss.shape, 0.0)
+            sens18 = F.fill(sbr_intra_disto_loss.dtype, sbr_intra_disto_loss.shape, 0.0)
+            sens19 = F.fill(interface_loss.dtype, interface_loss.shape, 0.0)
+            sens20 = F.fill(recall_intra.dtype, recall_intra.shape, 0.0)
+            sens21 = F.fill(recall_inter.dtype, recall_inter.shape, 0.0)
+            sens22 = F.fill(recall_interface.dtype, recall_interface.shape, 0.0)
+            sens23 = F.fill(perfect_recall_interface.dtype, perfect_recall_interface.shape, 0.0)
+            sens24 = F.fill(recall_inter1.dtype, recall_inter1.shape, 0.0)
+            sens25 = F.fill(recall_intra1.dtype, recall_intra1.shape, 0.0)
+
+            grads = self.grad(self.network, self.weights)(*inputs, (sens, sens1, sens2, sens3, sens4, sens5, sens6,
+                sens7, sens8, sens9, sens10, sens11, sens12, sens13, sens14, sens15, sens16, sens17, sens18, sens19,
+                sens20, sens21, sens22, sens23, sens24, sens25))
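+            # With sens_param=True the tuple above is the initial gradient fed
+            # into backprop; dividing by `sens` afterwards (grad_scale applies
+            # the reciprocal of the scale) recovers unscaled gradients, the
+            # same bookkeeping as ordinary loss scaling.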
+            grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_tensor(self.sens)), grads)
+            grads = self.grad_reducer(grads)
+            if self.enable_clip_grad:
+                if self.use_global_norm:
+                    grads = C.clip_by_global_norm(grads, self.gradient_clip_value)
+                else:
+                    grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, self.gradient_clip_value), grads)
+
+            loss_all = F.depend(loss_all, self.optimizer(grads))
+
+            return loss_all
+
+        out = self.network(*inputs)
+        return out
+
+
+class WithLossCell(nn.Cell):
+    """WithLossCell"""
+    def __init__(self, backbone, config):
+        super(WithLossCell, self).__init__(auto_prefix=False)
+        self._backbone = backbone
+        self.loss_net = LossNet(config).to_float(mstype.float32)
+
+    def construct(self, aatype, residue_index, template_aatype, template_all_atom_masks, template_all_atom_positions,
+                  asym_id, sym_id, entity_id, seq_mask, msa_mask, target_feat, msa_feat,
+                  extra_msa, extra_msa_deletion_value, extra_msa_mask,
+                  residx_atom37_to_atom14, atom37_atom_exists,
+                  sbr, sbr_mask, interface_mask,
+                  prev_pos, prev_msa_first_row, prev_pair,
+                  pseudo_beta, pseudo_beta_mask, residx_atom14_to_atom37,
+                  backbone_affine_tensor, backbone_affine_mask, rigidgroups_gt_frames,
+                  rigidgroups_gt_exists, rigidgroups_alt_gt_frames, torsion_angles_sin_cos, chi_mask,
+                  atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous, atom14_gt_exists,
+                  atom14_atom_exists, atom14_alt_gt_exists, all_atom_positions, all_atom_mask,
+                  true_msa, bert_mask,
+                  restype_atom14_bond_lower_bound, restype_atom14_bond_upper_bound, atomtype_radius,
+                  use_clamped_fape, filter_by_solution, asym_mask):
+        """construct"""
+        if self.train_backward:
+            dist_logits, bin_edges, experimentally_logits, masked_logits, aligned_error_logits, aligned_error_breaks, \
+                atom14_pred_positions, final_affines, angles_sin_cos_new, predicted_lddt_logits, structure_traj, \
+                sidechain_frames, sidechain_atom_pos, um_angles_sin_cos_new, final_atom_positions = \
+                self._backbone(aatype, residue_index, template_aatype, template_all_atom_masks,
+                               template_all_atom_positions, asym_id, sym_id, entity_id, seq_mask, msa_mask,
+                               target_feat, msa_feat, extra_msa, extra_msa_deletion_value, extra_msa_mask,
+                               residx_atom37_to_atom14, atom37_atom_exists,
+                               sbr, sbr_mask, interface_mask, prev_pos, prev_msa_first_row, prev_pair)
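+            # The backbone returns raw head outputs; LossNet combines them
+            # with the ground-truth features into the 26-element loss tuple
+            # unpacked by TrainOneStepCell above.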
+            out = self.loss_net(dist_logits, bin_edges, pseudo_beta, pseudo_beta_mask,
+                                experimentally_logits, atom37_atom_exists, all_atom_mask, true_msa,
+                                masked_logits, bert_mask, atom14_pred_positions, residue_index, aatype,
+                                residx_atom14_to_atom37, restype_atom14_bond_lower_bound,
+                                restype_atom14_bond_upper_bound, seq_mask, atomtype_radius, final_affines,
+                                aligned_error_breaks, aligned_error_logits, angles_sin_cos_new,
+                                um_angles_sin_cos_new, backbone_affine_tensor, backbone_affine_mask,
+                                atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous,
+                                atom14_gt_exists, atom14_atom_exists, atom14_alt_gt_exists,
+                                final_atom_positions, all_atom_positions, predicted_lddt_logits,
+                                structure_traj, rigidgroups_gt_frames, rigidgroups_gt_exists,
+                                rigidgroups_alt_gt_frames,
+                                sidechain_frames, sidechain_atom_pos, torsion_angles_sin_cos,
+                                chi_mask, use_clamped_fape, filter_by_solution, asym_id, asym_mask,
+                                sbr, sbr_mask, interface_mask)
+        else:
+            out = self._backbone(aatype, residue_index, template_aatype, template_all_atom_masks,
+                                 template_all_atom_positions, asym_id, sym_id, entity_id, seq_mask, msa_mask,
+                                 target_feat, msa_feat, extra_msa, extra_msa_deletion_value, extra_msa_mask,
+                                 residx_atom37_to_atom14, atom37_atom_exists, sbr, sbr_mask, interface_mask,
+                                 prev_pos, prev_msa_first_row, prev_pair)
+        return out
diff --git a/MindSPONGE/applications/research/Grasp/module/head.py b/MindSPONGE/applications/research/Grasp/module/head.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81e8eb4cf309736f1193fe624c3c6bf805c22d4
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/module/head.py
@@ -0,0 +1,276 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""prediction heads"""
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+import mindspore.numpy as mnp
+from mindspore import Tensor
+from mindspore.ops import functional as F
+from mindsponge1.cell.initializer import lecun_init
+
+
+class PredictedLDDTHead(nn.Cell):
+    """Head to predict the per-residue LDDT to be used as a confidence measure."""
+
+    def __init__(self, config, seq_channel):
+        super().__init__()
+        self.config = config
+        self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5)
+        self.act_0 = nn.Dense(seq_channel, self.config.num_channels,
+                              weight_init=lecun_init(seq_channel, initializer_name='relu')
+                              ).to_float(mstype.float16)
+        self.act_1 = nn.Dense(self.config.num_channels, self.config.num_channels,
+                              weight_init=lecun_init(self.config.num_channels, initializer_name='relu')
+                              ).to_float(mstype.float16)
+        self.logits = nn.Dense(self.config.num_channels, self.config.num_bins, weight_init='zeros'
+                               ).to_float(mstype.float16)
+        self.relu = nn.ReLU()
+
+    def construct(self, rp_structure_module):
+        """Builds PredictedLDDTHead module."""
+        act = rp_structure_module
+        act = self.input_layer_norm(act.astype(mstype.float32))
+        act = self.act_0(act)
+        act = self.relu(act.astype(mstype.float32))
+        act = self.act_1(act)
+        act = self.relu(act.astype(mstype.float32))
+        logits = self.logits(act)
+        return logits
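+
+
+# The per-residue confidence is recovered from PredictedLDDTHead logits by an
+# expectation over bin centres; a minimal sketch (assuming `num_bins` equally
+# spaced bins on [0, 1], as in AlphaFold2):
+#
+#     probs = nn.Softmax(-1)(logits.astype(mstype.float32))  # [N_res, num_bins]
+#     bin_width = 1.0 / num_bins
+#     centers = mnp.arange(0.5 * bin_width, 1.0, bin_width)  # bin midpoints
+#     plddt = 100.0 * mnp.sum(probs * centers, -1)           # per-residue pLDDT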
1.9.8 "Distogram prediction" + """ + + def __init__(self, config, pair_dim): + super().__init__() + self.config = config + self.half_logits = nn.Dense(pair_dim, self.config.num_bins, weight_init='zeros') + self.first_break = self.config.first_break + self.last_break = self.config.last_break + self.num_bins = self.config.num_bins + + def construct(self, pair): + """Builds DistogramHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [N_res, N_res, c_z]. + + Returns: + Dictionary containing: + * logits: logits for distogram, shape [N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [N_bins - 1,]. + """ + half_logits = self.half_logits(pair) + + logits = half_logits + mnp.swapaxes(half_logits, -2, -3) + breaks = mnp.linspace(self.first_break, self.last_break, self.num_bins - 1) + + return logits, breaks + + +# class DistogramHEAD_sbr(nn.Cell): +# """Head to predict a distogram. +# """ + +# def __init__(self, pair_dim): +# super().__init__() +# self.first_break = 8 +# self.last_break = 24 +# self.num_bins = 10 +# self.half_logits = nn.Dense(pair_dim, self.num_bins, weight_init='zeros') + +# def construct(self, pair): +# """Builds DistogramHead module. + +# Arguments: +# representations: Dictionary of representations, must contain: +# * 'pair': pair representation, shape [N_res, N_res, c_z]. + +# Returns: +# Dictionary containing: +# * logits: logits for distogram, shape [N_res, N_res, N_bins]. +# * bin_breaks: array containing bin breaks, shape [N_bins - 1,]. +# """ +# half_logits = self.half_logits(pair) + +# logits = half_logits + mnp.swapaxes(half_logits, -2, -3) +# breaks = mnp.linspace(self.first_break, self.last_break, self.num_bins - 1) + +# return logits, breaks + + +class ExperimentallyResolvedHead(nn.Cell): + """Predicts if an atom is experimentally resolved in a high-res structure. + + Only trained on high-resolution X-ray crystals & cryo-EM. + Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction' + """ + + def __init__(self, seq_channel): + super().__init__() + self.logits = nn.Dense(seq_channel, 37, weight_init='zeros') + + def construct(self, single): + """Builds ExperimentallyResolvedHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'single': Single representation, shape [N_res, c_s]. + + Returns: + Dictionary containing: + * 'logits': logits of shape [N_res, 37], + log probability that an atom is resolved in atom37 representation, + can be converted to probability by applying sigmoid. + """ + logits = self.logits(single) + return logits + + +class MaskedMsaHead(nn.Cell): + """Head to predict MSA at the masked locations. + + The MaskedMsaHead employs a BERT-style objective to reconstruct a masked + version of the full MSA, based on a linear projection of + the MSA representation. + Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction" + """ + + def __init__(self, config, msa_channel): + super().__init__() + self.config = config + self.logits = nn.Dense(msa_channel, self.config.num_output, weight_init='zeros') + + def construct(self, msa): + """Builds MaskedMsaHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'msa': MSA representation, shape [N_seq, N_res, c_m]. + + Returns: + Dictionary containing: + * 'logits': logits of shape [N_seq, N_res, N_aatype] with + (unnormalized) log probabilies of predicted aatype at position. 
+ """ + # del batch + logits = self.logits(msa) + return logits + + +class PredictedAlignedErrorHead(nn.Cell): + """Head to predict the distance errors in the backbone alignment frames. + + Can be used to compute predicted TM-Score. + Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction" + """ + + def __init__(self, config, pair_dim): + super().__init__() + self.config = config + self.num_bins = self.config.num_bins + self.max_error_bin = self.config.max_error_bin + # self.min_error_bin = self.config.min_error_bin + self.logits = nn.Dense(pair_dim, self.num_bins, weight_init='zeros') + + def construct(self, pair): + """Builds PredictedAlignedErrorHead module. + + Arguments: + * 'pair': pair representation, shape [N_res, N_res, c_z]. + + Returns: + * logits: logits for aligned error, shape [N_res, N_res, N_bins]. + * breaks: array containing bin breaks, shape [N_bins - 1]. + """ + logits = self.logits(pair) + breaks = mnp.linspace(0, self.max_error_bin, self.num_bins - 1) + return logits, breaks + + +class EstogramHead(nn.Cell): + """Head to predict estogram.""" + + def __init__(self, first_break, last_break, num_bins): + super().__init__() + self.first_break = first_break + self.last_break = last_break + self.num_bins = num_bins + + self.breaks = mnp.linspace(self.first_break, self.last_break, self.num_bins) + self.width = self.breaks[1] - self.breaks[0] + + self.centers = self.breaks + 0.5 * self.width + + self.softmax = nn.Softmax(-1) + self.zero = Tensor([0.]) + + def compute_estogram(self, distogram_logits, decoy_distance_mat): + """compute estogram matrix. + Arguments: + distogram_logits: [N_res, N_res, N_bins]. + decoy_distance_mat: [N_res, N_res] + Returns: + estogram: shape [N_res, N_res, N_bins]. + esto_centers: shape [N_res, N_res, N_bins]. + """ + square_centers = mnp.reshape(self.centers, (1, 1, -1)) + estogram = self.softmax(distogram_logits) + esto_centers = square_centers - mnp.expand_dims(decoy_distance_mat, -1) + return estogram, esto_centers + + def construct(self, distogram_logits, pseudo_beta, pseudo_beta_mask, cutoff=15.): + """construct""" + positions = pseudo_beta + pad_mask = mnp.expand_dims(pseudo_beta_mask, 1) + pad_mask_2d = pad_mask * mnp.transpose(pad_mask, (1, 0)) + pad_mask_2d *= (1. - mnp.eye(pad_mask_2d.shape[1])) + + dist_xyz = mnp.square(mnp.expand_dims(positions, axis=1) - mnp.expand_dims(positions, axis=0)) + dmat_decoy = mnp.sqrt(1e-10 + mnp.sum(dist_xyz.astype(mstype.float32), -1)) + + estogram, esto_centers = self.compute_estogram(distogram_logits, dmat_decoy) + pair_errors = mnp.sum(estogram * esto_centers, -1) + + p1 = self._integrate(distogram_logits, mnp.abs(esto_centers) < 0.5).astype(mnp.float32) + p2 = self._integrate(distogram_logits, mnp.abs(esto_centers) < 1.0).astype(mnp.float32) + p3 = self._integrate(distogram_logits, mnp.abs(esto_centers) < 2.0).astype(mnp.float32) + p4 = self._integrate(distogram_logits, mnp.abs(esto_centers) < 4.0).astype(mnp.float32) + + p0 = self._integrate(distogram_logits, self.centers < cutoff).astype(mnp.float32) + pred_mask2d = p0 * pad_mask_2d + + norm = mnp.sum(pred_mask2d, -1) + 1e-6 + p1 = mnp.sum(p1 * pred_mask2d, -1) + p2 = mnp.sum(p2 * pred_mask2d, -1) + p3 = mnp.sum(p3 * pred_mask2d, -1) + p4 = mnp.sum(p4 * pred_mask2d, -1) + + plddt = 0.25 * (p1 + p2 + p3 + p4) / norm + + return plddt, pred_mask2d, pair_errors + + def _integrate(self, distogram_logits, integrate_masks): + """compute estogram matrix. + Arguments: + distogram_logits: [N_res, N_res, N_bins]. 
+ integrate_masks: [N_res, N_res, N_bins] + Returns: + v: shape [N_res, N_res]. + """ + probs = self.softmax(distogram_logits) + integrate_masks = F.cast(integrate_masks, mnp.float32) + v = mnp.sum(probs * integrate_masks, -1) + return v diff --git a/MindSPONGE/applications/research/Grasp/module/loss_module.py b/MindSPONGE/applications/research/Grasp/module/loss_module.py new file mode 100644 index 0000000000000000000000000000000000000000..c0049ab2d6a12cf22b61a99724bf5c938c7fb688 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/module/loss_module.py @@ -0,0 +1,495 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""loss module""" + +import mindspore as ms +import mindspore.communication.management as D +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore import Tensor +from mindspore.context import ParallelMode +from mindspore.parallel._utils import _get_parallel_mode +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindsponge1.common import residue_constants +from mindsponge1.common.utils import pseudo_beta_fn +from mindsponge1.common.geometry import invert_point, quaternion_from_tensor, vecs_expand_dims +from mindsponge1.metrics.structure_violations import get_structural_violations, compute_renamed_ground_truth, backbone, \ + sidechain, supervised_chi, local_distance_difference_test +from mindsponge1.metrics import BalancedMSE, BinaryFocal, MultiClassFocal +# from restraint_sample import BINS +from restraint_sample import BINS + + +class LossNet(nn.Cell): + """loss net""" + + def __init__(self, config, train_fold=True): + super(LossNet, self).__init__() + self.config = config + self.num_res = config.seq_length + self.num_bins = config.heads.distogram.num_bins + self.resolution = config.heads.resolution + self.distogram_weight = config.heads.distogram.weight + self.distogram_one_hot = nn.OneHot(depth=self.num_bins, axis=-1) + self.distogram_one_hot_sbr = nn.OneHot(depth=len(BINS)+1, axis=-1) + self.exp_min_resolution = config.heads.experimentally_resolved.min_resolution + self.exp_max_resolution = config.heads.experimentally_resolved.max_resolution + self.exp_res_filter_by_resolution = config.heads.experimentally_resolved.filter_by_resolution + self.experimentally_weight = config.heads.experimentally_resolved.weight + self.exp_res_mask = Tensor(1, ms.float32) \ + if not self.exp_res_filter_by_resolution or \ + (self.exp_min_resolution <= self.resolution <= self.exp_max_resolution) else Tensor(0, ms.float32) + + self.ael_min_resolution = config.heads.predicted_aligned_error.min_resolution + self.ael_max_resolution = config.heads.predicted_aligned_error.max_resolution + self.ael_res_filter_by_resolution = config.heads.predicted_aligned_error.filter_by_resolution + self.ael_res_mask = Tensor(1, ms.float32) \ + if not self.ael_res_filter_by_resolution or \ + 
(self.ael_min_resolution <= self.resolution <= self.ael_max_resolution) else Tensor(0, ms.float32) + self.aligned_one_hot = nn.OneHot(depth=config.heads.predicted_aligned_error.num_bins) + + self.plddt_min_resolution = config.heads.predicted_lddt.min_resolution + self.plddt_max_resolution = config.heads.predicted_lddt.max_resolution + self.plddt_res_filter_by_resolution = config.heads.predicted_lddt.filter_by_resolution + self.plddt_res_mask = Tensor(1, ms.float32) \ + if not self.plddt_res_filter_by_resolution or \ + (self.plddt_min_resolution <= self.resolution <= self.plddt_max_resolution) else Tensor(0, ms.float32) + self.plddt_weight = config.heads.predicted_lddt.weight + + self.masked_one_hot = nn.OneHot(depth=config.heads.masked_msa.num_output, axis=-1) + self.masked_weight = config.heads.masked_msa.weight + self.sidechain_weight_frac = config.structure_module.sidechain.weight_frac + self.angle_norm_weight = config.structure_module.angle_norm_weight + self.chi_weight = config.structure_module.chi_weight + self.chi_pi_periodic = mnp.asarray(residue_constants.chi_pi_periodic, ms.float32) + + self.violation_tolerance_factor = config.structure_module.violation_tolerance_factor + self.clash_overlap_tolerance = config.structure_module.clash_overlap_tolerance + self.sidechain_atom_clamp_distance = config.structure_module.sidechain.atom_clamp_distance + # self.sidechain_atom_clamp_distance = self.sidechain_atom_clamp_distance * 1000 + self.sidechain_length_scale = config.structure_module.sidechain.length_scale + self.fape_clamp_distance = config.structure_module.fape.clamp_distance + self.fape_loss_unit_distance = config.structure_module.fape.loss_unit_distance + self.predicted_lddt_num_bins = config.heads.predicted_lddt.num_bins + self.c_one_hot = nn.OneHot(depth=14) + self.n_one_hot = nn.OneHot(depth=14) + self.zeros = Tensor(0, ms.int32) + self.twos = Tensor(2, ms.int32) + self.dists_mask_i = mnp.eye(14, 14) + self.cys_sg_idx = Tensor(5, ms.int32) + self.train_fold = train_fold + self.sigmoid_cross_entropy = P.SigmoidCrossEntropyWithLogits() + + def softmax_cross_entropy(self, logits, labels): + """Computes softmax cross entropy given logits and one-hot class labels.""" + loss = -mnp.sum(labels * nn.LogSoftmax()(logits), axis=-1) + return mnp.asarray(loss) + + def softmax_cross_entropy_binary(self, logits, labels, binary_mask): + """Computes softmax cross entropy given logits and one-hot class labels.""" + labels_positive = mnp.sum(labels * binary_mask, axis=-1) + pred_positive = mnp.sum(nn.Softmax()(logits) * binary_mask, axis=-1) + loss = -((labels_positive * P.Log()(pred_positive + 1e-10)) + (1 - labels_positive) * P.Log()(1 - pred_positive + 1e-10)) + return mnp.asarray(loss) + + def distogram_loss(self, logits, bin_edges, pseudo_beta, pseudo_beta_mask, sbr_intra_mask, sbr_inter_mask): + """Log loss of a distogram.""" + positions = pseudo_beta + mask = pseudo_beta_mask + + sq_breaks = mnp.square(bin_edges) + dist_t = mnp.square(mnp.expand_dims(positions, axis=-2) - mnp.expand_dims(positions, axis=-3)) + dist2 = P.ReduceSum(True)(dist_t.astype(ms.float32), -1) + aa = (dist2 > sq_breaks).astype(ms.float32) + + true_bins = P.ReduceSum()(aa, -1) + true_bins = true_bins.astype(ms.int32) + errors = self.softmax_cross_entropy(labels=self.distogram_one_hot(true_bins), logits=logits) + square_mask = mnp.expand_dims(mask, axis=-2) * mnp.expand_dims(mask, axis=-1) + + sbr_inter_mask *= square_mask + sbr_intra_mask *= square_mask + avg_error = (P.ReduceSum()(errors * square_mask, (-2, -1)) / + 
(1e-6 + P.ReduceSum()(square_mask.astype(ms.float32), (-2, -1)))) + # sbr_inter_disto_loss = (P.ReduceSum()(errors * sbr_inter_mask, (-2, -1)) / + # (1e-6 + P.ReduceSum()(sbr_inter_mask.astype(ms.float32), (-2, -1)))) + # sbr_intra_disto_loss = (P.ReduceSum()(errors * sbr_intra_mask, (-2, -1)) / + # (1e-6 + P.ReduceSum()(sbr_intra_mask.astype(ms.float32), (-2, -1)))) + + dist2 = dist2[..., 0] + loss = avg_error + true_dist = mnp.sqrt(1e-6 + dist2) + return loss, true_dist #, sbr_intra_disto_loss, sbr_inter_disto_loss + + def get_mask(self, sbr_mask, asym_id): + sbr_mask = P.Cast()(sbr_mask, ms.float32) + intra_chain_mask = P.Cast()(asym_id[:, None] == asym_id[None, :], ms.float32) + sbr_intra_mask = intra_chain_mask * sbr_mask + sbr_inter_mask = P.Cast()((1 - intra_chain_mask) * sbr_mask, ms.float32) + return sbr_intra_mask, sbr_inter_mask + + + def experimentally_loss(self, experimentally_logits, atom37_atom_exists, all_atom_mask, filter_by_solution): + """experimentally_loss""" + logits = experimentally_logits + + # Does the atom appear in the amino acid? + atom_exists = atom37_atom_exists + # Is the atom resolved in the experiment? Subset of atom_exists, + # *except for OXT* + all_atom_mask = all_atom_mask.astype(mnp.float32) + + xent = self.sigmoid_cross_entropy(logits, all_atom_mask) + loss = P.ReduceSum()(xent * atom_exists) / (1e-8 + P.ReduceSum()(atom_exists.astype(ms.float32))) + loss = loss * filter_by_solution + loss *= self.exp_res_mask + return loss + + def masked_head_loss(self, true_msa, logits, bert_mask): + """masked_head_loss""" + errors = self.softmax_cross_entropy(logits=logits, labels=self.masked_one_hot(true_msa)) + loss = (P.ReduceSum()(errors * bert_mask, (-2, -1)) / + (1e-8 + P.ReduceSum()(bert_mask.astype(ms.float32), (-2, -1)))) + return loss + + + + # todo + def structure_loss(self, atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous, + atom14_gt_exists, atom14_atom_exists, final_atom14_positions, atom14_alt_gt_exists, + residue_index, aatype, residx_atom14_to_atom37, lower_bound, upper_bound, seq_mask, + atomtype_radius, angles_sin_cos, um_angles_sin_cos, traj, backbone_affine_tensor, + backbone_affine_mask, rigidgroups_gt_frames, rigidgroups_gt_exists, rigidgroups_alt_gt_frames, + pred_frames, pred_positions, sin_cos_true_chi, torsion_angle_mask, use_clamped_fape, asym_id, + sbr_mask): + """structure_loss""" + atom14_pred_positions = final_atom14_positions + # Compute renaming and violations. 
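+        # compute_renamed_ground_truth keeps, per residue, whichever of the two
+        # 180-degree-symmetric side-chain atom namings lies closer to the
+        # prediction (AlphaFold2's symmetric ground-truth renaming), so the
+        # FAPE and violation terms are not penalised for an arbitrary naming.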
+ alt_naming_is_better, renamed_atom14_gt_positions, renamed_atom14_gt_exists = \ + compute_renamed_ground_truth(atom14_gt_positions, + atom14_alt_gt_positions, + atom14_atom_is_ambiguous, + atom14_gt_exists, + atom14_pred_positions, + atom14_alt_gt_exists) + (bonds_c_n_loss_mean, angles_ca_c_n_loss_mean, angles_c_n_ca_loss_mean, _, + _, _, clashes_per_atom_loss_sum, _, per_atom_loss_sum, _, _, _, + clashes_per_atom_clash_count, per_atom_clash_count) = \ + get_structural_violations(atom14_atom_exists, residue_index, aatype, residx_atom14_to_atom37, + atom14_pred_positions, asym_id, self.violation_tolerance_factor, + self.clash_overlap_tolerance, lower_bound, upper_bound, atomtype_radius, + self.c_one_hot(self.twos), self.n_one_hot(self.zeros), self.dists_mask_i, + self.cys_sg_idx) + + bond_loss = bonds_c_n_loss_mean + angles_ca_c_n_loss_mean * 0.3 + angles_c_n_ca_loss_mean * 0.3 + + #num_atoms = P.ReduceSum()(atom14_atom_exists.astype(ms.float32)) + num_atoms = P.ReduceSum()(clashes_per_atom_clash_count + per_atom_clash_count) + clash_loss = P.ReduceSum()(clashes_per_atom_loss_sum + per_atom_loss_sum) / (1e-6 + num_atoms) + + structure_violation_loss = bond_loss + clash_loss + + # from structure module result + _, fape_loss, no_clamp, fape_nc_intra, fape_nc_inter, sbr_intra_fape_loss, sbr_inter_fape_loss = \ + backbone(traj, backbone_affine_tensor, backbone_affine_mask, \ + self.fape_clamp_distance, self.fape_loss_unit_distance, use_clamped_fape, asym_id, sbr_mask) + + loss_sidechain = sidechain(alt_naming_is_better, rigidgroups_gt_frames, rigidgroups_alt_gt_frames, + rigidgroups_gt_exists, renamed_atom14_gt_positions, renamed_atom14_gt_exists, + self.sidechain_atom_clamp_distance, self.sidechain_length_scale, pred_frames, + pred_positions) + angle_norm_loss = supervised_chi(seq_mask, aatype, sin_cos_true_chi, torsion_angle_mask, + angles_sin_cos, um_angles_sin_cos, self.chi_weight, + self.angle_norm_weight, self.chi_pi_periodic) + return fape_loss, loss_sidechain, angle_norm_loss, structure_violation_loss, no_clamp, bond_loss, \ + clash_loss, fape_nc_intra, fape_nc_inter, sbr_intra_fape_loss, sbr_inter_fape_loss + + def predicted_lddt_loss(self, final_atom_positions, all_atom_positions, all_atom_mask, predicted_lddt_logits, + filter_by_solution): + """predicted_lddt_loss""" + pred_all_atom_pos = final_atom_positions + true_all_atom_pos = all_atom_positions + lddt_ca = local_distance_difference_test( + predicted_points=pred_all_atom_pos[None, :, 1, :], + true_points=true_all_atom_pos[None, :, 1, :], + true_points_mask=all_atom_mask[None, :, 1:2].astype(mnp.float32), + cutoff=15., + per_residue=True)[0] + + lddt_ca = F.stop_gradient(lddt_ca) + + bin_index = mnp.floor(lddt_ca * self.predicted_lddt_num_bins).astype(ms.int32) + + # protect against out of range for lddt_ca == 1 + bin_index = mnp.minimum(bin_index, self.predicted_lddt_num_bins - 1) + lddt_ca_one_hot = nn.OneHot(depth=self.predicted_lddt_num_bins)(bin_index) + + logits = predicted_lddt_logits + errors = self.softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits) + + mask_ca = all_atom_mask[:, 1] + mask_ca = mask_ca.astype(mnp.float32) + loss = P.ReduceSum()(errors * mask_ca) / P.ReduceSum()(P.ReduceSum()(mask_ca) + 1e-8) + loss = loss * filter_by_solution + loss *= self.plddt_res_mask + + return loss + + def aligned_error_loss(self, final_affines, backbone_affine_tensor, backbone_affine_mask, pae_breaks, pae_logits, + filter_by_solution): + """aligned_error_loss""" + # Shape (num_res, 7) predict affine + _, rotation_pd, 
translation_pd = quaternion_from_tensor(final_affines) + translation_point_pd = vecs_expand_dims(translation_pd, -2) + rotation_pd_tensor = rotation_pd + # Shape (num_res, 7) true affine + _, rotation_gt, translation_gt = quaternion_from_tensor(backbone_affine_tensor) + translation_point_tr = vecs_expand_dims(translation_gt, -2) + rotation_gt_tensor = rotation_gt + mask = backbone_affine_mask + square_mask = (mask[:, None] * mask[None, :]).astype(ms.float32) + breaks = pae_breaks + logits = pae_logits + + local_frames_pd = invert_point(translation_point_pd, rotation_pd_tensor, translation_pd, extra_dims=1) + local_frames_gt = invert_point(translation_point_tr, rotation_gt_tensor, translation_gt, extra_dims=1) + # todo to be checked + error_dist2 = mnp.square(local_frames_pd[0] - local_frames_gt[0]) + \ + mnp.square(local_frames_pd[1] - local_frames_gt[1]) + \ + mnp.square(local_frames_pd[2] - local_frames_gt[2]) + error_dist2 = F.stop_gradient(error_dist2) + # # Compute the squared error for each alignment. + sq_breaks = mnp.square(breaks) + true_bins = P.ReduceSum()((error_dist2[..., None] > sq_breaks).astype(mnp.float32), -1) + + errors = self.softmax_cross_entropy(labels=self.aligned_one_hot(true_bins.astype(ms.int32)), logits=logits) + + loss = (P.ReduceSum()(errors * square_mask, (-2, -1)) / + (1e-8 + P.ReduceSum()(square_mask, (-2, -1)))) + loss = loss * filter_by_solution + loss *= self.ael_res_mask + + return loss + + def distance_rmsd_loss(self, predicted_atom_positions, label_atom_positions, rmsd_mask): + dist1 = P.Sqrt()(P.ReduceSum()(P.Square()(predicted_atom_positions[None]-predicted_atom_positions[:,None]), -1) + 1e-8) + dist2 = P.Sqrt()(P.ReduceSum()(P.Square()(label_atom_positions[None] -label_atom_positions[:,None]) , -1) + 1e-8) + error = P.Square()(dist1 - dist2) + loss = P.Sqrt()(P.ReduceSum()(error * rmsd_mask) / (P.ReduceSum()(rmsd_mask) + 1e-8) + 1e-8)/ self.fape_loss_unit_distance + return loss + + def backbone_drmsd_loss(self, pseudo_beta_pred, pseudo_beta_gt, final_atom_positions, all_atom_positions, mask): + rmsd_loss_cb = self.distance_rmsd_loss(pseudo_beta_pred, pseudo_beta_gt, mask.astype(ms.float32)) + rmsd_loss_ca = self.distance_rmsd_loss(final_atom_positions[:, 1, :], all_atom_positions[:, 1, :], mask.astype(ms.float32)) + rmsd_loss_c = self.distance_rmsd_loss(final_atom_positions[:, 0, :], all_atom_positions[:, 0, :], mask.astype(ms.float32)) + rmsd_loss_n = self.distance_rmsd_loss(final_atom_positions[:, 2, :], all_atom_positions[:, 2, :], mask.astype(ms.float32)) + backbone_drmsd_loss = (rmsd_loss_ca + rmsd_loss_c + rmsd_loss_n + rmsd_loss_cb) + return backbone_drmsd_loss + + def get_asym_centres(self, pos, asym_mask, eps): + pos = P.ExpandDims()(pos, 0) * P.ExpandDims()(asym_mask, 2) # [NC, NR, 3] + return mnp.sum(pos, -2) / (mnp.sum(asym_mask, -1)[..., None] + eps) # [NC, 3] + + def chain_centre_mass_loss(self, pseudo_beta, pseudo_beta_mask, aatype, final_atom_positions, asym_mask, eps=1e-8): + pseudo_beta_pred = pseudo_beta_fn(aatype, final_atom_positions, None) + asym_mask = asym_mask * P.ExpandDims()(pseudo_beta_mask, 0) # [NC, NR] + asym_exists = P.Cast()(asym_mask.sum(-1) > 0, ms.float16) + # asym_exists = asym_mask.any(axis=-1) # [NC, ] + + pred_centres = self.get_asym_centres(pseudo_beta_pred, asym_mask, eps) # [NC, 3] + true_centres = self.get_asym_centres(pseudo_beta, asym_mask, eps) # [NC, 3] + + pred_dists = P.Sqrt()((P.Square()(pred_centres[None] - pred_centres[:, None])).sum(-1) + 1e-8) # [NC, NC] + true_dists = 
P.Sqrt()((P.Square()(true_centres[None] - true_centres[:, None])).sum(-1) + 1e-8) # [NC, NC] + chain_centre_mass_loss = P.Square()(mnp.clip(P.Abs()(pred_dists - true_dists) - 4, xmin=0, xmax=None)) * 0.0025 + # chain_centre_mass_loss = P.Square()(mnp.clip(pred_dists - true_dists + 4, xmin=None, xmax=0)) * 0.0025 + + chain_centre_mask = (asym_exists[None, :] * asym_exists[:, None]).astype(ms.float32) + chain_centre_mass_loss = (chain_centre_mass_loss * chain_centre_mask).sum() / (chain_centre_mask.sum() + eps) + + return chain_centre_mass_loss + + + def sbr_drmsd_loss(self, final_atom_positions, all_atom_positions, pseudo_beta_gt, aatype, sbr_intra_mask, sbr_inter_mask): + + pseudo_beta_pred = pseudo_beta_fn(aatype, final_atom_positions, None) # CA as CB for glycine + # positional rmsd sbr loss + sbr_intra_drmsd_loss = self.backbone_drmsd_loss(pseudo_beta_pred, pseudo_beta_gt, final_atom_positions, all_atom_positions, \ + sbr_intra_mask) + sbr_inter_drmsd_loss = self.backbone_drmsd_loss(pseudo_beta_pred, pseudo_beta_gt, final_atom_positions, all_atom_positions, \ + sbr_inter_mask) + return sbr_intra_drmsd_loss, sbr_inter_drmsd_loss, pseudo_beta_pred + + def compute_sbr_loss(self, pseudo_pred_dist, bin_edges_sbr, sbr, sbr_intra_mask, sbr_inter_mask, delta=2.0): + not_high_bin = (sbr <= 1.0/(len(bin_edges_sbr) + 1)).astype(ms.float32) + upper_1d = P.Concat()((bin_edges_sbr, Tensor([10000], ms.float32))) + lower_1d = P.Concat()((Tensor([0], ms.float32), bin_edges_sbr)) + upper_2d = (upper_1d-1e6*not_high_bin).max(-1) + lower_2d = (lower_1d+1e6*not_high_bin).min(-1) + lower_error = mnp.clip(lower_2d- delta - pseudo_pred_dist, 0, 30) + upper_error = mnp.clip(pseudo_pred_dist - upper_2d - delta, 0, 30) + error = (lower_error + upper_error)*(upper_2d > lower_2d) + error_inter = (error * sbr_inter_mask).sum() / (sbr_inter_mask.sum() + 1e-8) + error_intra = (error * sbr_intra_mask).sum() / (sbr_intra_mask.sum() + 1e-8) + recall = (error<=0).astype(ms.float32) + recall_inter1 = (recall * sbr_inter_mask).sum() / (sbr_inter_mask.sum() + 1e-8) + recall_intra1 = (recall * sbr_intra_mask).sum() / (sbr_intra_mask.sum() + 1e-8) + return error_intra, error_inter, recall_inter1, recall_intra1 + + def compute_recall(self, pseudo_pred_dist, bin_edges_sbr, sbr, sbr_intra_mask, sbr_inter_mask): + # compute recall + sbr_binary = (sbr > 1.0/(len(bin_edges_sbr) + 1)).astype(ms.float32) + aa = (mnp.expand_dims(pseudo_pred_dist, -1) > bin_edges_sbr).astype(ms.float32) + pred_bins = P.ReduceSum()(aa, -1) + pred_bins = pred_bins.astype(ms.int32) + sbr_pred = mnp.sum(self.distogram_one_hot_sbr(pred_bins) * sbr_binary, axis=-1) + recall_intra = (sbr_pred * sbr_intra_mask).sum() / (sbr_intra_mask.sum() + 1e-8) + recall_inter = (sbr_pred * sbr_inter_mask).sum() / (sbr_inter_mask.sum() + 1e-8) + return recall_intra, recall_inter + + def interface_loss(self, interface_mask, asym_id, pseudo_pred_dist, pseudo_beta_mask, true_dist, delta=1.0, eps=1e-8): + inter_chain_mask = P.Cast()(asym_id[:, None] != asym_id[None, :], ms.float32) + pseudo_pred_dist += (1.0 - pseudo_beta_mask * inter_chain_mask) * 1e9 + # dist += (1.0 - pseudo_beta_mask * inter_chain_mask) * 1e9 + perfect_dist = pseudo_pred_dist + (true_dist > 8) * 1e9 + interface_min_dist = pseudo_pred_dist.min(axis=-1) + + + error = mnp.clip(interface_min_dist - (8.0 + delta), 0.0, 30.0) + error = (error * interface_mask).sum() / (interface_mask.sum() + eps) + + is_interface = P.Cast()(interface_min_dist < 8.0, ms.float32) + is_perfect_interface = 
P.Cast()(perfect_dist.min(axis=-1) < 8.0, ms.float32)
+        recall_interface = (is_interface * interface_mask).sum() / interface_mask.sum()
+        perfect_recall_interface = (is_perfect_interface * interface_mask).sum() / interface_mask.sum()
+        return error, recall_interface, perfect_recall_interface
+
+    def construct(self, distogram_logits, bin_edges, pseudo_beta, pseudo_beta_mask, experimentally_logits,
+                  atom37_atom_exists, all_atom_mask, true_msa, masked_logits, bert_mask,
+                  final_atom14_positions, residue_index, aatype, residx_atom14_to_atom37, lower_bound, upper_bound,
+                  seq_mask, atomtype_radius, final_affines, pae_breaks, pae_logits, angles_sin_cos,
+                  um_angles_sin_cos, backbone_affine_tensor, backbone_affine_mask, atom14_gt_positions,
+                  atom14_alt_gt_positions, atom14_atom_is_ambiguous, atom14_gt_exists, atom14_atom_exists,
+                  atom14_alt_gt_exists, final_atom_positions, all_atom_positions, predicted_lddt_logits, traj,
+                  rigidgroups_gt_frames, rigidgroups_gt_exists, rigidgroups_alt_gt_frames,
+                  pred_frames, pred_positions, sin_cos_true_chi, torsion_angle_mask, use_clamped_fape,
+                  filter_by_solution, asym_id, asym_mask, sbr, sbr_mask, interface_mask):
+        """construct"""
+        distogram_loss = 0.0
+        sbr_intra_disto_loss = 0.0
+        sbr_inter_disto_loss = 0.0
+        masked_loss = 0.0
+        sbr_intra_mask, sbr_inter_mask = self.get_mask(sbr_mask, asym_id)
+
+        if self.train_fold:
+            distogram_loss, dist = \
+                self.distogram_loss(distogram_logits, bin_edges, pseudo_beta,
+                                    pseudo_beta_mask, sbr_intra_mask, sbr_inter_mask)
+            distogram_loss = distogram_loss * self.distogram_weight  # 0.3
+
+            masked_loss = self.masked_head_loss(true_msa, masked_logits, bert_mask)
+            masked_loss = self.masked_weight * masked_loss  # 2
+
+        fape_loss, loss_sidechain, angle_norm_loss, structure_violation_loss, no_clamp, bond_loss, clash_loss, \
+            fape_nc_intra, fape_nc_inter, sbr_intra_fape_loss, sbr_inter_fape_loss = \
+            self.structure_loss(atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous,
+                                atom14_gt_exists, atom14_atom_exists, final_atom14_positions,
+                                atom14_alt_gt_exists, residue_index, aatype, residx_atom14_to_atom37,
+                                lower_bound, upper_bound, seq_mask, atomtype_radius, angles_sin_cos,
+                                um_angles_sin_cos, traj, backbone_affine_tensor,
+                                backbone_affine_mask, rigidgroups_gt_frames, rigidgroups_gt_exists,
+                                rigidgroups_alt_gt_frames,
+                                pred_frames, pred_positions, sin_cos_true_chi, torsion_angle_mask, use_clamped_fape,
+                                asym_id, sbr_mask)
+        structure_violation_loss = structure_violation_loss * 0.03
+
+        predict_lddt_loss = self.predicted_lddt_loss(final_atom_positions, all_atom_positions, all_atom_mask,
+                                                     predicted_lddt_logits, filter_by_solution)
+        predict_lddt_loss = self.plddt_weight * predict_lddt_loss  # 0.01
+
+        chain_centre_mass_loss = self.chain_centre_mass_loss(pseudo_beta, pseudo_beta_mask, aatype,
+                                                             final_atom_positions, asym_mask)
+        aligned_error_loss = self.aligned_error_loss(final_affines, backbone_affine_tensor,
+                                                     backbone_affine_mask, pae_breaks, pae_logits,
+                                                     filter_by_solution)
+        aligned_error_loss = aligned_error_loss * 0.1
+
+        l_fape_side = 0.5 * loss_sidechain
+        l_fape_backbone = 0.5 * fape_loss
+        l_anglenorm = angle_norm_loss
+
+        # sbr loss
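+        # Restraint (sbr) supervision combines three complementary penalties on
+        # the restrained residue pairs: FAPE on restrained frames (computed in
+        # structure_loss above), a backbone distance-RMSD, and a distogram-style
+        # band penalty on predicted pseudo-beta distances (compute_sbr_loss below).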
+        sbr_intra_drmsd_loss, sbr_inter_drmsd_loss, pseudo_beta_pred = \
+            self.sbr_drmsd_loss(final_atom_positions, all_atom_positions, pseudo_beta, aatype,
+                                sbr_intra_mask, sbr_inter_mask)
+
+        # sbr recall
+        bin_edges_sbr = mnp.arange(4, 33, 1).astype(ms.float32)
+        pseudo_pred_dist = P.Sqrt()(P.ReduceSum()(P.Square()(pseudo_beta_pred[:, None] - pseudo_beta_pred[None]), -1) + 1e-8)
+        true_dist = P.Sqrt()(P.ReduceSum()(P.Square()(pseudo_beta[:, None] - pseudo_beta[None]), -1) + 1e-8)
+
+        recall_intra, recall_inter = self.compute_recall(pseudo_pred_dist, bin_edges_sbr, sbr,
+                                                         sbr_intra_mask, sbr_inter_mask)
+        sbr_intra_disto_loss, sbr_inter_disto_loss, recall_inter1, recall_intra1 = \
+            self.compute_sbr_loss(pseudo_pred_dist, bin_edges_sbr, sbr, sbr_intra_mask, sbr_inter_mask)
+
+        sbr_inter_fape_loss = sbr_inter_fape_loss * 0.5
+        sbr_intra_fape_loss = sbr_intra_fape_loss * 0.5
+
+        sbr_inter_drmsd_loss = sbr_inter_drmsd_loss * 0.05
+        sbr_intra_drmsd_loss = sbr_intra_drmsd_loss * 0.05
+
+        sbr_inter_disto_loss *= 0.01
+        sbr_intra_disto_loss *= 0.01
+
+        all_sbr_loss = sbr_intra_disto_loss + sbr_inter_disto_loss + \
+            mnp.clip(sbr_inter_fape_loss + sbr_inter_drmsd_loss, 0.0, 1.5) + \
+            mnp.clip(sbr_intra_fape_loss + sbr_intra_drmsd_loss, 0.0, 1.5)
+
+        # interface loss
+        interface_loss, recall_interface, perfect_recall_interface = \
+            self.interface_loss(interface_mask, asym_id, pseudo_pred_dist, pseudo_beta_mask, true_dist)
+        interface_loss *= 0.5
+
+        loss = l_fape_side + \
+            l_fape_backbone + \
+            l_anglenorm + \
+            distogram_loss + \
+            masked_loss + \
+            predict_lddt_loss + \
+            mnp.clip(structure_violation_loss, 0.0, 1) + \
+            aligned_error_loss + \
+            mnp.clip(chain_centre_mass_loss, 0.0, 1) + \
+            all_sbr_loss + \
+            mnp.clip(interface_loss, 0.0, 1)
+
+        loss = loss * P.Sqrt()(P.ReduceSum()(all_atom_mask[:, 0]))
+
+        return loss, l_fape_side, l_fape_backbone, l_anglenorm, distogram_loss, masked_loss, predict_lddt_loss, \
+            structure_violation_loss, no_clamp, fape_nc_intra, fape_nc_inter, chain_centre_mass_loss, aligned_error_loss, \
+            sbr_inter_fape_loss, sbr_inter_drmsd_loss, sbr_inter_disto_loss, \
+            sbr_intra_fape_loss, sbr_intra_drmsd_loss, sbr_intra_disto_loss, interface_loss, \
+            recall_intra, recall_inter, recall_interface, perfect_recall_interface, recall_inter1, recall_intra1
diff --git a/MindSPONGE/applications/research/Grasp/module/lr.py b/MindSPONGE/applications/research/Grasp/module/lr.py
new file mode 100644
index 0000000000000000000000000000000000000000..787bb72322efd22b3910744d865e04d0efd09e86
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/module/lr.py
@@ -0,0 +1,35 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""learning rate""" + +import math +import numpy as np + + +def cos_decay_lr(start_step, lr_init, lr_min, lr_max, decay_steps, warmup_steps, total_steps): + """cosine decay learning rate""" + lr_each_step = [] + for i in range(total_steps): + if i < warmup_steps: + lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) + lr = float(lr_init) + lr_inc * (i + 1) + elif i < decay_steps: + lr = lr_min + 0.5 * (lr_max-lr_min) * (1 + math.cos((i - warmup_steps) / (decay_steps - warmup_steps) * math.pi)) + else: + lr = lr_min + + lr_each_step.append(lr) + lr_each_step = np.array(lr_each_step).astype(np.float32) + return lr_each_step[start_step:] diff --git a/MindSPONGE/applications/research/Grasp/module/structure.py b/MindSPONGE/applications/research/Grasp/module/structure.py new file mode 100644 index 0000000000000000000000000000000000000000..705a6be8bc462da4256438b4956cceb840cb1ad6 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/module/structure.py @@ -0,0 +1,248 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""structure module""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.ops import functional as F +from mindsponge1.cell import InvariantPointAttention +import mindsponge1.common.residue_constants as residue_constants +from mindsponge1.cell.initializer import lecun_init +from mindsponge1.common.utils import torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, \ + atom14_to_atom37 +from mindsponge1.common.geometry import initial_affine, quaternion_to_tensor, pre_compose, vecs_scale,\ + vecs_to_tensor, vecs_expand_dims, rots_expand_dims + + +class MultiRigidSidechain(nn.Cell): + """Class to make side chain atoms.""" + + def __init__(self, config, single_repr_dim): + super().__init__() + self.config = config + self.input_projection = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.input_projection_1 = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.relu = nn.ReLU() + self.resblock1 = nn.Dense(self.config.num_channel, self.config.num_channel, + weight_init=lecun_init(self.config.num_channel, + initializer_name='relu')) + self.resblock2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros') + self.resblock1_1 = nn.Dense(self.config.num_channel, self.config.num_channel, + weight_init=lecun_init(self.config.num_channel, initializer_name='relu')) + self.resblock2_1 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros') + self.unnormalized_angles = nn.Dense(self.config.num_channel, 14, + 
weight_init=lecun_init(self.config.num_channel))
+        self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group)
+        self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions)
+        self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask)
+        self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame)
+        self.l2_normalize = ops.L2Normalize(axis=-1, epsilon=1e-12)
+
+    def construct(self, rotation, translation, act, initial_act, aatype):
+        """Predict side chains using rotation and translation representations.
+
+        Args:
+            rotation: The rotation matrices of the backbone frames.
+            translation: The translation vectors of the backbone frames.
+            act: current activations from the structure module.
+            initial_act: initial activations (input of the structure module).
+            aatype: Amino acid type representations.
+
+        Returns:
+            angles, positions and new frames
+        """
+
+        act1 = self.input_projection(self.relu(act))
+        init_act1 = self.input_projection_1(self.relu(initial_act))
+        # Sum the activation list (equivalent to concat then Linear).
+        act = act1 + init_act1
+
+        # Mapping with some residual blocks.
+        # resblock1
+        old_act = act
+        act = self.resblock1(self.relu(act))
+        act = self.resblock2(self.relu(act))
+        act += old_act
+        # resblock2
+        old_act = act
+        act = self.resblock1_1(self.relu(act))
+        act = self.resblock2_1(self.relu(act))
+        act += old_act
+
+        # Map activations to torsion angles. Shape: (num_res, 14).
+        num_res = act.shape[0]
+        unnormalized_angles = self.unnormalized_angles(self.relu(act))
+
+        unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2])
+        angles = self.l2_normalize(unnormalized_angles)
+
+        backb_to_global = ((rotation[0], rotation[1], rotation[2],
+                            rotation[3], rotation[4], rotation[5],
+                            rotation[6], rotation[7], rotation[8]),
+                           (translation[0], translation[1], translation[2]))
+
+        all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles,
+                                                        self.restype_rigid_group_default_frame)
+
+        pred_positions = frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global,
+                                                                       self.restype_atom14_to_rigid_group,
+                                                                       self.restype_atom14_rigid_group_positions,
+                                                                       self.restype_atom14_mask)
+
+        atom_pos = pred_positions
+        frames = all_frames_to_global
+        res = (angles, unnormalized_angles, atom_pos, frames)
+        return res
+
+
+class FoldIteration(nn.Cell):
+    """A single iteration of the main structure module loop."""
+
+    def __init__(self, config, pair_dim, single_repr_dim):
+        super().__init__()
+        self.config = config
+        self.drop_out = nn.Dropout(keep_prob=0.9)
+        self.attention_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5)
+        self.transition_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5)
+        self.transition = nn.Dense(self.config.num_channel, config.num_channel,
+                                   weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
+        self.transition_1 = nn.Dense(self.config.num_channel, self.config.num_channel,
+                                     weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
+        self.transition_2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
+        self.relu = nn.ReLU()
+        self.affine_update = nn.Dense(self.config.num_channel, 6, weight_init='zeros')
+        self.attention_module = InvariantPointAttention(self.config.num_head,
+                                                        self.config.num_scalar_qk,
+                                                        self.config.num_scalar_v,
+                                                        self.config.num_point_v,
+                                                        self.config.num_point_qk,
+                                                        self.config.num_channel,
+                                                        pair_dim)
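+        # Invariant point attention (IPA, Jumper et al. 2021 Alg. 22) attends
+        # in the local backbone frames, so the single-representation update is
+        # invariant to the current global rigid placement; the side-chain
+        # module below turns the updated activations into torsion angles.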
self.mu_side_chain = MultiRigidSidechain(self.config.sidechain, single_repr_dim)
+        # construct() scales translations by position_scale (see vecs_scale below);
+        # mirror the multimer variant so the attribute actually exists.
+        self.position_scale = self.config.position_scale
+
+    def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype):
+        """construct"""
+        attn = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation)
+        act += attn
+        act = self.drop_out(act)
+        act = self.attention_layer_norm(act)
+        # Transition
+        input_act = act
+        act = self.transition(act)
+        act = self.relu(act)
+        act = self.transition_1(act)
+        act = self.relu(act)
+        act = self.transition_2(act)
+
+        act += input_act
+        act = self.drop_out(act)
+        act = self.transition_layer_norm(act)
+
+        # This block corresponds to
+        # Jumper et al. (2021) Alg. 23 "Backbone update"
+        # Affine update
+        affine_update = self.affine_update(act)
+        quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, affine_update)
+        translation1 = vecs_scale(translation, self.position_scale)
+        rotation1 = rotation
+        angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames = \
+            self.mu_side_chain(rotation1, translation1, act, initial_act, aatype)
+
+        affine_output = quaternion_to_tensor(quaternion, translation)
+        quaternion = F.stop_gradient(quaternion)
+        rotation = F.stop_gradient(rotation)
+        res = (act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
+               atom_pos, frames)
+        return res
+
+
+class StructureModule(nn.Cell):
+    """StructureModule as a network head."""
+
+    def __init__(self, config, single_repr_dim, pair_dim):
+        super(StructureModule, self).__init__()
+        self.config = config.structure_module
+        self.seq_length = config.seq_length
+        self.fold_iteration = FoldIteration(self.config, pair_dim, single_repr_dim)
+        self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5)
+        self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel,
+                                           weight_init=lecun_init(single_repr_dim))
+        self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5)
+        self.num_layer = self.config.num_layer
+        self.indice0 = Tensor(
+            np.arange(self.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32"))
+        self.traj_w = Tensor(np.array([1.]
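The `pre_compose` call in `FoldIteration.construct` realizes Jumper et al. (2021) Alg. 23: the 6-channel `affine_update` splits into a 3-vector that perturbs the frame's quaternion and a 3-vector translation expressed in the residue's local frame. The NumPy sketch below follows the AlphaFold 2 reference convention; it illustrates the math only and is not the `mindsponge1.common.geometry` implementation:

```python
import numpy as np

def quat_multiply(q, r):
    """Hamilton product of (w, x, y, z) quaternions."""
    w1, x1, y1, z1 = q
    w2, x2, y2, z2 = r
    return np.array([w1*w2 - x1*x2 - y1*y2 - z1*z2,
                     w1*x2 + x1*w2 + y1*z2 - z1*y2,
                     w1*y2 - x1*z2 + y1*w2 + z1*x2,
                     w1*z2 + x1*y2 - y1*x2 + z1*w2])

def quat_to_rot(q):
    """Rotation matrix of a unit quaternion (w, x, y, z)."""
    w, x, y, z = q
    return np.array([[1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)],
                     [2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)],
                     [2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)]])

def pre_compose_ref(quaternion, translation, affine_update):
    """Alg. 23 sketch: 3 channels nudge the rotation, 3 the local-frame translation."""
    b, t_update = affine_update[:3], affine_update[3:]
    q_new = quat_multiply(quaternion, np.concatenate(([1.0], b)))  # (1, b1, b2, b3) update
    q_new = q_new / np.linalg.norm(q_new)                          # renormalize
    t_new = translation + quat_to_rot(quaternion) @ t_update       # rotate into global frame
    return q_new, t_new

q1, t1 = pre_compose_ref(np.array([1.0, 0.0, 0.0, 0.0]), np.zeros(3),
                         np.array([0.1, 0.0, 0.0, 1.0, 2.0, 3.0]))
```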
* 4 + [self.config.position_scale] * 3), mstype.float32) + + def construct(self, single, pair, seq_mask, aatype, residx_atom37_to_atom14=None, atom37_atom_exists=None): + """construct""" + sequence_mask = seq_mask[:, None] + act = self.single_layer_norm(single) + initial_act = act + act = self.initial_projection(act) + quaternion, rotation, translation = initial_affine(self.seq_length) + act_2d = self.pair_layer_norm(pair) + # folder iteration + atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, act_iter = \ + self.iteration_operation(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype) + atom14_pred_positions = vecs_to_tensor(atom_pos)[-1] + sidechain_atom_pos = atom_pos + + atom37_pred_positions = atom14_to_atom37(atom14_pred_positions, + residx_atom37_to_atom14, + atom37_atom_exists, + self.indice0) + + structure_traj = affine_output_new * self.traj_w + final_affines = affine_output_new[-1] + final_atom_positions = atom37_pred_positions + final_atom_mask = atom37_atom_exists + rp_structure_module = act_iter + res = (final_atom_positions, final_atom_mask, rp_structure_module, atom14_pred_positions, final_affines, \ + angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, sidechain_atom_pos, structure_traj) + return res + + def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, + aatype): + """iteration_operation""" + affine_init = () + angles_sin_cos_init = () + um_angles_sin_cos_init = () + atom_pos_batch = () + frames_batch = () + + for _ in range(self.num_layer): + act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \ + atom_pos, frames = \ + self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype) + + affine_init = affine_init + (affine_output[None, ...],) + angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],) + um_angles_sin_cos_init = um_angles_sin_cos_init + (unnormalized_angles_sin_cos[None, ...],) + atom_pos_batch += (mnp.concatenate(vecs_expand_dims(atom_pos, 0), axis=0)[:, None, ...],) + frames_batch += (mnp.concatenate(rots_expand_dims(frames[0], 0) + + vecs_expand_dims(frames[1], 0), axis=0)[:, None, ...],) + affine_output_new = mnp.concatenate(affine_init, axis=0) + angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0) + um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0) + frames_new = mnp.concatenate(frames_batch, axis=1) + atom_pos_new = mnp.concatenate(atom_pos_batch, axis=1) + res = (atom_pos_new, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames_new, act) + return res diff --git a/MindSPONGE/applications/research/Grasp/module/structure_multimer.py b/MindSPONGE/applications/research/Grasp/module/structure_multimer.py new file mode 100644 index 0000000000000000000000000000000000000000..11ac03e5fc5964463a0d27e4b4020cfa158353ca --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/module/structure_multimer.py @@ -0,0 +1,263 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
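Back in `StructureModule.construct`, `structure_traj` stacks one 7-vector affine per fold iteration, and `traj_w` rescales only the translation channels back to physical units. A NumPy sketch of that broadcast (the `position_scale` value here is illustrative):

```python
import numpy as np

num_layer, num_res, position_scale = 8, 4, 10.0            # illustrative values
structure_traj = np.random.randn(num_layer, num_res, 7).astype(np.float32)
traj_w = np.array([1.0] * 4 + [position_scale] * 3, dtype=np.float32)
scaled = structure_traj * traj_w                           # broadcasts over the last axis
assert np.allclose(scaled[..., :4], structure_traj[..., :4])   # quaternion part untouched
assert np.allclose(scaled[..., 4:], structure_traj[..., 4:] * position_scale)
```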
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""structure module""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.ops import functional as F +from mindspore.ops import operations as P +import mindsponge1.common.residue_constants as residue_constants +from mindsponge1.cell.initializer import lecun_init +from mindsponge1.common.utils import torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, \ + atom14_to_atom37 +from mindsponge1.common.geometry import initial_affine, quaternion_to_tensor, pre_compose, vecs_scale,\ + vecs_to_tensor, vecs_expand_dims, rots_expand_dims +from cell.equivariant import MultimerInvariantPointAttention +# from mindsponge1.cell import InvariantPointAttention + + +class MultiRigidSidechain(nn.Cell): + """Class to make side chain atoms.""" + + def __init__(self, config, single_repr_dim): + super().__init__() + self.config = config + self.input_projection = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.input_projection_1 = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.relu = nn.ReLU() + self.resblock1 = nn.Dense(self.config.num_channel, self.config.num_channel, + weight_init=lecun_init(self.config.num_channel, + initializer_name='relu')) + self.resblock2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros') + self.resblock1_1 = nn.Dense(self.config.num_channel, self.config.num_channel, + weight_init=lecun_init(self.config.num_channel, initializer_name='relu')) + self.resblock2_1 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros') + self.unnormalized_angles = nn.Dense(self.config.num_channel, 14, + weight_init=lecun_init(self.config.num_channel)) + self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group) + self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions) + self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask) + self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame) + self.l2_normalize = ops.L2Normalize(axis=-1, epsilon=1e-12) + + def construct(self, rotation, translation, act, initial_act, aatype): + """Predict side chains using rotation and translation representations. + + Args: + rotation: The rotation matrices. + translation: A translation matrices. + act: updated pair activations from structure module + initial_act: initial act representations (input of structure module) + aatype: Amino acid type representations + + Returns: + angles, positions and new frames + """ + + act1 = self.input_projection(self.relu(act)) + init_act1 = self.input_projection_1(self.relu(initial_act)) + # Sum the activation list (equivalent to concat then Linear). + act = act1 + init_act1 + + # Mapping with some residual blocks. 
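The residual mapping that follows is the standard two-layer pre-activation MLP block, with the second Dense zero-initialized so each block starts out as the identity. A stand-alone MindSpore sketch of the same pattern (the `channels` parameter is hypothetical):

```python
import mindspore.nn as nn

class ResBlock(nn.Cell):
    """x + W2(relu(W1(relu(x)))): the residual pattern used by MultiRigidSidechain."""
    def __init__(self, channels):
        super().__init__()
        self.relu = nn.ReLU()
        self.dense1 = nn.Dense(channels, channels)
        self.dense2 = nn.Dense(channels, channels, weight_init='zeros')  # identity at init

    def construct(self, x):
        return x + self.dense2(self.relu(self.dense1(self.relu(x))))
```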
+ # resblock1 + old_act = act + act = self.resblock1(self.relu(act)) + act = self.resblock2(self.relu(act)) + act += old_act + # resblock2 + old_act = act + act = self.resblock1_1(self.relu(act)) + act = self.resblock2_1(self.relu(act)) + act += old_act + + # Map activations to torsion angles. Shape: (num_res, 14). + num_res = act.shape[0] + unnormalized_angles = self.unnormalized_angles(self.relu(act)) + + unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2]) + angles = self.l2_normalize(unnormalized_angles) + + backb_to_global = ((rotation[0], rotation[1], rotation[2], + rotation[3], rotation[4], rotation[5], + rotation[6], rotation[7], rotation[8]), + (translation[0], translation[1], translation[2])) + + all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles, + self.restype_rigid_group_default_frame) + + pred_positions = frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global, + self.restype_atom14_to_rigid_group, + self.restype_atom14_rigid_group_positions, + self.restype_atom14_mask) + + atom_pos = pred_positions + frames = all_frames_to_global + res = (angles, unnormalized_angles, atom_pos, frames) + return res + + +class MultimerFoldIteration(nn.Cell): + """A single iteration of the main structure module loop.""" + + def __init__(self, config, pair_dim, single_repr_dim, device_num): + super().__init__() + self.config = config + self.drop_out = nn.Dropout(keep_prob=0.9) + self.attention_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5) + self.transition_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5) + self.transition = nn.Dense(self.config.num_channel, config.num_channel, + weight_init=lecun_init(self.config.num_channel, initializer_name='relu')) + self.transition_1 = nn.Dense(self.config.num_channel, self.config.num_channel, + weight_init=lecun_init(self.config.num_channel, initializer_name='relu')) + self.transition_2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros') + self.relu = nn.ReLU() + self.affine_update = nn.Dense(self.config.num_channel, 6, weight_init='zeros') + self.attention_module = MultimerInvariantPointAttention(self.config.num_head, + self.config.num_scalar_qk, + self.config.num_scalar_v, + self.config.num_point_v, + self.config.num_point_qk, + self.config.num_channel, + pair_dim, + device_num) + # self.attention_module = InvariantPointAttention(self.config.num_head, + # self.config.num_scalar_qk, + # self.config.num_scalar_v, + # self.config.num_point_v, + # self.config.num_point_qk, + # self.config.num_channel, + # pair_dim) + self.mu_side_chain = MultiRigidSidechain(self.config.sidechain, single_repr_dim) + self.position_scale = self.config.position_scale + + def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype, sbr_act, sbr_mask, interface_mask): + """construct""" + # print("debug self.attention_module", act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype, sbr_act, sbr_mask, interface_mask) + attn = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation, sbr_act, sbr_mask, interface_mask) + # self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation, sbr_act, sbr_mask, interface_mask) + act += attn + act = self.drop_out(act) + act = self.attention_layer_norm(act) + # Transition + input_act = act + act = self.transition(act) + act = self.relu(act) + act = self.transition_1(act) + act = 
self.relu(act) + act = self.transition_2(act) + + act += input_act + act = self.drop_out(act) + act = self.transition_layer_norm(act) + # This block corresponds to + # Jumper et al. (2021) Alg. 23 "Backbone update" + # Affine update + affine_update = self.affine_update(act) + quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, affine_update) + translation1 = vecs_scale(translation, self.position_scale) # 20.0 + rotation1 = rotation + angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames = \ + self.mu_side_chain(rotation1, translation1, act, initial_act, aatype) + affine_output = quaternion_to_tensor(quaternion, translation) + quaternion = F.stop_gradient(quaternion) + rotation = F.stop_gradient(rotation) + res = (act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \ + atom_pos, frames) + return res + + +class MultimerStructureModule(nn.Cell): + """StructureModule as a network head.""" + + def __init__(self, config, single_repr_dim, pair_dim, device_num): + super(MultimerStructureModule, self).__init__() + self.config = config.structure_module + self.seq_length = config.seq_length + self.fold_iteration = MultimerFoldIteration(self.config, pair_dim, single_repr_dim, device_num) + self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5) + self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5) + self.num_layer = self.config.num_layer + self.indice0 = Tensor( + np.arange(self.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32")) + self.traj_w = Tensor(np.array([1.] * 4 + [self.config.position_scale] * 3), mstype.float32) + self.concat_0_3_3 = P.Concat(0).shard(((1, device_num, 1), (1, device_num, 1), (1, device_num, 1))) + self.concat_1_8_4 = P.Concat(1).shard(((1, 1, device_num, 1), (1, 1, device_num, 1), (1, 1, device_num, 1), (1, 1, device_num, 1), + (1, 1, device_num, 1), (1, 1, device_num, 1), (1, 1, device_num, 1), (1, 1, device_num, 1))) + + def construct(self, single, pair, seq_mask, aatype, + sbr_act, sbr_mask, interface_mask, + residx_atom37_to_atom14=None, atom37_atom_exists=None): + """construct""" + sequence_mask = seq_mask[:, None] + act = self.single_layer_norm(single) + initial_act = act + act = self.initial_projection(act) + quaternion, rotation, translation = initial_affine(self.seq_length) + act_2d = self.pair_layer_norm(pair) + + # folder iteration + # print("MultimerStructureModule construct act", act) + atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, act_iter = \ + self.iteration_operation(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype, sbr_act, sbr_mask, interface_mask) + + atom14_pred_positions = vecs_to_tensor(atom_pos)[-1] + sidechain_atom_pos = atom_pos + + atom37_pred_positions = atom14_to_atom37(atom14_pred_positions, + residx_atom37_to_atom14, + atom37_atom_exists, + self.indice0) + structure_traj = affine_output_new * self.traj_w + final_affines = affine_output_new[-1] + final_atom_positions = atom37_pred_positions + final_atom_mask = atom37_atom_exists + rp_structure_module = act_iter + res = (final_atom_positions, final_atom_mask, rp_structure_module, atom14_pred_positions, final_affines, \ + angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, sidechain_atom_pos, structure_traj) + return res + + def iteration_operation(self, act, 
act_2d, sequence_mask, quaternion, rotation, translation, initial_act, + aatype, sbr_act, sbr_mask, interface_mask): + """iteration_operation""" + affine_init = () + angles_sin_cos_init = () + um_angles_sin_cos_init = () + atom_pos_batch = () + frames_batch = () + + for _ in range(self.num_layer): + act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \ + atom_pos, frames = self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype, sbr_act, sbr_mask, interface_mask) + affine_init = affine_init + (affine_output[None, ...],) + angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],) + um_angles_sin_cos_init = um_angles_sin_cos_init + (unnormalized_angles_sin_cos[None, ...],) + atom_pos_batch += (self.concat_0_3_3(vecs_expand_dims(atom_pos, 0))[:, None, ...],) + frames_batch += (mnp.concatenate(rots_expand_dims(frames[0], 0) + + vecs_expand_dims(frames[1], 0), axis=0)[:, None, ...],) + affine_output_new = mnp.concatenate(affine_init, axis=0) + angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0) + um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0) + frames_new = mnp.concatenate(frames_batch, axis=1) + atom_pos_new = self.concat_1_8_4(atom_pos_batch) + res = (atom_pos_new, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames_new, act) + return res \ No newline at end of file diff --git a/MindSPONGE/applications/research/Grasp/module/template_embedding.py b/MindSPONGE/applications/research/Grasp/module/template_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..df42a9adbd38dbcfd0fca2e7edf208b40fbe22ae --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/module/template_embedding.py @@ -0,0 +1,570 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
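`iteration_operation` above accumulates one output per fold iteration in a Python tuple and concatenates along a fresh leading axis; in graph mode this unrolls into a static loop over `num_layer`. The same pattern in plain NumPy (shapes illustrative):

```python
import numpy as np

num_layer, num_res = 8, 4                                  # illustrative sizes
affine_init = ()
for _ in range(num_layer):
    affine_output = np.random.randn(num_res, 7).astype(np.float32)   # one result per layer
    affine_init = affine_init + (affine_output[None, ...],)          # add a leading axis
affine_output_new = np.concatenate(affine_init, axis=0)
assert affine_output_new.shape == (num_layer, num_res, 7)            # [layer, res, quat+trans]
```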
+# ============================================================================ +'''TEMPLATE''' +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore import Tensor +from mindsponge1.cell.initializer import lecun_init +from mindsponge1.common.utils import dgram_from_positions, _memory_reduce, pseudo_beta_fn#, DgramFromPositionsCell +from mindsponge1.common.geometry import make_transform_from_reference, quat_affine, invert_point +from mindsponge1.common.residue_constants import atom_order +from mindsponge1.cell import Attention, TriangleAttention, Transition, TriangleMultiplication +from common.geometry import multimer_rigids_get_unit_vector +# from mindspore import lazy_inline +# from mindspore import Layout + + +class TemplatePairStack(nn.Cell): + '''template pair stack''' + + def __init__(self, config): + super(TemplatePairStack, self).__init__() + self.config = config.template.template_pair_stack + self.num_block = self.config.num_block + batch_size = 0 + self.slice = config.slice.template_pair_stack + start_node_cfg = self.config.triangle_attention_starting_node + self.triangle_attention_starting_node = TriangleAttention(start_node_cfg.orientation, + start_node_cfg.num_head, + start_node_cfg.key_dim, + start_node_cfg.gating, + 64, + batch_size, + self.slice.triangle_attention_starting_node) + end_node_cfg = self.config.triangle_attention_ending_node + self.triangle_attention_ending_node = TriangleAttention(end_node_cfg.orientation, + end_node_cfg.num_head, + end_node_cfg.key_dim, + end_node_cfg.gating, + 64, + batch_size, + self.slice.triangle_attention_ending_node) + # Hard Code + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, + 64, + batch_size, + self.slice.pair_transition) + + mul_outgoing_cfg = self.config.triangle_multiplication_outgoing + self.triangle_multiplication_outgoing = TriangleMultiplication(mul_outgoing_cfg.num_intermediate_channel, + mul_outgoing_cfg.equation, + layer_norm_dim=64, + batch_size=batch_size) + mul_incoming_cfg = self.config.triangle_multiplication_incoming + self.triangle_multiplication_incoming = TriangleMultiplication(mul_incoming_cfg.num_intermediate_channel, + mul_incoming_cfg.equation, + layer_norm_dim=64, + batch_size=batch_size) + + def construct(self, pair_act, pair_mask, index=None): + if not self.num_block: + return pair_act + + pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index) + pair_act = pair_act + self.pair_transition(pair_act, index) + return pair_act + + +class SingleTemplateEmbedding(nn.Cell): + '''single template embedding''' + + def __init__(self, config, mixed_precision): + super(SingleTemplateEmbedding, self).__init__() + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_bins = self.config.dgram_features.num_bins + self.min_bin = self.config.dgram_features.min_bin + self.max_bin = self.config.dgram_features.max_bin + + self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + self.embedding2d = nn.Dense(88, self.num_channels, + 
weight_init=lecun_init(88, initializer_name='relu')) + # if is_training: + template_layers = nn.CellList() + for _ in range(self.config.template_pair_stack.num_block): + template_pair_stack_block = TemplatePairStack(config) + template_layers.append(template_pair_stack_block) + self.template_pair_stack = template_layers + + self.one_hot = nn.OneHot(depth=22, axis=-1) + self.n, self.ca, self.c = [atom_order[a] for a in ('N', 'CA', 'C')] + + self.use_template_unit_vector = self.config.use_template_unit_vector + layer_norm_dim = 64 + self.output_layer_norm = nn.LayerNorm([layer_norm_dim,], epsilon=1e-5) + self.num_block = self.config.template_pair_stack.num_block + self.batch_block = 4 + + def construct(self, mask_2d, template_aatype, template_all_atom_masks, template_all_atom_positions, + template_pseudo_beta_mask, template_pseudo_beta): + '''construct''' + num_res = template_aatype[0, ...].shape[0] + template_mask_2d_temp = P.ExpandDims()(template_pseudo_beta_mask, -1) * \ + P.ExpandDims()(template_pseudo_beta_mask, 1) + template_dgram_temp = dgram_from_positions(template_pseudo_beta, self.num_bins, self.min_bin, + self.max_bin, self._type) + + to_concat_temp = (template_dgram_temp, P.ExpandDims()(template_mask_2d_temp, -1)) + aatype_temp = self.one_hot(template_aatype) + aatype_temp = P.Cast()(aatype_temp, self._type) + to_concat_temp = to_concat_temp + (P.Tile()(P.ExpandDims()(aatype_temp, 1), (1, num_res, 1, 1)), + P.Tile()(P.ExpandDims()(aatype_temp, 2), (1, 1, num_res, 1))) + + rot_temp, trans_temp = make_transform_from_reference(template_all_atom_positions[:, :, self.n], + template_all_atom_positions[:, :, self.ca], + template_all_atom_positions[:, :, self.c]) + + _, rotation_tmp, translation_tmp = quat_affine(None, trans_temp, rot_temp) + points_tmp = [P.ExpandDims()(translation_tmp[0], -2), + P.ExpandDims()(translation_tmp[1], -2), + P.ExpandDims()(translation_tmp[2], -2)] + affine_vec_tmp = invert_point(points_tmp, rotation_tmp, translation_tmp, extra_dims=1) + inv_distance_scalar_tmp = P.Rsqrt()(1e-6 + P.Square()(affine_vec_tmp[0]) + P.Square()(affine_vec_tmp[1]) + \ + P.Square()(affine_vec_tmp[2])) + template_mask_tmp = (template_all_atom_masks[:, :, self.n] * + template_all_atom_masks[:, :, self.ca] * + template_all_atom_masks[:, :, self.c]) + template_mask_2d_tmp = P.ExpandDims()(template_mask_tmp, -1) * P.ExpandDims()(template_mask_tmp, 1) + + inv_distance_scalar_tmp = inv_distance_scalar_tmp * template_mask_2d_tmp + unit_vector_tmp = (P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[0], -1), + P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[1], -1), + P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[2], -1)) + + if not self.use_template_unit_vector: + unit_vector_tmp = (P.ZerosLike()(unit_vector_tmp[0]), P.ZerosLike()(unit_vector_tmp[1]), + P.ZerosLike()(unit_vector_tmp[2])) + to_concat_temp = to_concat_temp + unit_vector_tmp + (P.ExpandDims()(template_mask_2d_tmp, -1),) + act_tmp = P.Concat(-1)(to_concat_temp) + + act_tmp = act_tmp * P.ExpandDims()(template_mask_2d_tmp, -1) + act_tmp = self.embedding2d(act_tmp) + + act_tmp = P.Split(0, self.batch_block)(act_tmp) + act = () + for i in range(self.batch_block): + act = act + (P.Squeeze()(act_tmp[i]),) + + output = [] + for i in range(self.batch_block): + act_batch = act[i] + for j in range(self.num_block): + act_batch = self.template_pair_stack[j](act_batch, mask_2d) + slice_act = P.Reshape()(act_batch, ((1,) + P.Shape()(act_batch))) + output.append(slice_act) + + act_tmp_loop = P.Concat()(output) + act_tmp 
= self.output_layer_norm(act_tmp_loop) + return act_tmp + + +class TemplateEmbedding(nn.Cell): + '''template embedding''' + + def __init__(self, config, mixed_precision=True): + super(TemplateEmbedding, self).__init__() + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + self.template_embedder = SingleTemplateEmbedding(config, mixed_precision) + self.template_pointwise_attention = Attention(self.config.attention.num_head, + self.config.attention.key_dim, + self.config.attention.gating, + q_data_dim=128, m_data_dim=64, + output_dim=128, batch_size=None) + self.slice_num = config.slice.template_embedding + + + def compute(self, flat_query, flat_templates, input_mask): + embedding = self.template_pointwise_attention(flat_query, flat_templates, input_mask, index=None, + nonbatched_bias=None) + return embedding + + + def construct(self, query_embedding, template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, mask_2d): + '''construct''' + num_templates = template_mask.shape[0] + num_channels = self.num_channels + num_res = query_embedding.shape[0] + query_num_channels = query_embedding.shape[-1] + mask_2d = F.depend(mask_2d, query_embedding) + template_pair_representation = self.template_embedder(mask_2d, template_aatype, + template_all_atom_masks, template_all_atom_positions, + template_pseudo_beta_mask, + template_pseudo_beta) + flat_query = P.Reshape()(query_embedding, (num_res * num_res, 1, query_num_channels)) + flat_templates = P.Reshape()( + P.Transpose()(template_pair_representation, (1, 2, 0, 3)), + (num_res * num_res, num_templates, num_channels)) + template_mask_bias = P.ExpandDims()(P.ExpandDims()(P.ExpandDims()(template_mask, 0), 1), 2) - 1.0 + input_mask = 1e4 * template_mask_bias + batched_inputs = (flat_query, flat_templates) + nonbatched_inputs = (input_mask,) + embedding = _memory_reduce(self.compute, batched_inputs, nonbatched_inputs, self.slice_num) + embedding = P.Reshape()(embedding, (num_res, num_res, query_num_channels)) + # No gradients if no templates. + embedding = embedding * (P.ReduceSum()(template_mask) > 0.) 
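The reshape/transpose just above is pure shape bookkeeping: every `(i, j)` pair position becomes one batch row whose single query attends over the template axis. In NumPy terms (sizes illustrative):

```python
import numpy as np

num_res, num_templates, c_t, c_z = 4, 2, 64, 128           # illustrative sizes
query = np.zeros((num_res, num_res, c_z), np.float32)      # pair representation
templates = np.zeros((num_templates, num_res, num_res, c_t), np.float32)

flat_query = query.reshape(num_res * num_res, 1, c_z)
flat_templates = templates.transpose(1, 2, 0, 3).reshape(num_res * num_res, num_templates, c_t)
# attention(flat_query, flat_templates) yields (num_res * num_res, 1, c_z),
# which is then reshaped back to (num_res, num_res, c_z).
assert flat_query.shape == (16, 1, 128) and flat_templates.shape == (16, 2, 64)
```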
+ return embedding + + +class MultimerTemplatePairStack(nn.Cell): + '''multimer template pair stack''' + + def __init__(self, config, device_num): + super(MultimerTemplatePairStack, self).__init__() + self.config = config.template.template_pair_stack + self.num_block = self.config.num_block + batch_size = 0 + self.slice = config.slice.template_pair_stack + start_node_cfg = self.config.triangle_attention_starting_node + self.triangle_attention_starting_node = TriangleAttention(start_node_cfg.orientation, + start_node_cfg.num_head, + start_node_cfg.key_dim, + start_node_cfg.gating, + 64, + device_num, + batch_size, + self.slice.triangle_attention_starting_node) + end_node_cfg = self.config.triangle_attention_ending_node + self.triangle_attention_ending_node = TriangleAttention(end_node_cfg.orientation, + end_node_cfg.num_head, + end_node_cfg.key_dim, + end_node_cfg.gating, + 64, + device_num, + batch_size, + self.slice.triangle_attention_ending_node) + # Hard Code + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, + 64, + device_num, + batch_size, + self.slice.pair_transition) + + mul_outgoing_cfg = self.config.triangle_multiplication_outgoing + self.triangle_multiplication_outgoing = TriangleMultiplication(mul_outgoing_cfg.num_intermediate_channel, + mul_outgoing_cfg.equation, + 64, + device_num, + batch_size=batch_size) + mul_incoming_cfg = self.config.triangle_multiplication_incoming + self.triangle_multiplication_incoming = TriangleMultiplication(mul_incoming_cfg.num_intermediate_channel, + mul_incoming_cfg.equation, + 64, + device_num, + batch_size=batch_size) + self.add = P.Add().shard(((1, device_num, 1),(1, device_num, 1))) + + def construct(self, pair_act, pair_mask, index=None): + if not self.num_block: + return pair_act + # print("debug pair_act 277", pair_act) + pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index) + # pair_act = self.add(pair_act, self.triangle_multiplication_outgoing(pair_act, pair_mask, index)) + # print("debug pair_act 279", pair_act) + pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index) + # pair_act = self.add(pair_act, self.triangle_multiplication_incoming(pair_act, pair_mask, index)) + # print("debug pair_act 281", pair_act) + pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index) + # pair_act = self.add(pair_act, self.triangle_attention_starting_node(pair_act, pair_mask, index)) + # print("debug pair_act 283", pair_act) + pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index) + # pair_act = self.add(pair_act, self.triangle_attention_ending_node(pair_act, pair_mask, index)) + # print("debug pair_act 285", pair_act) + pair_act = pair_act + self.pair_transition(pair_act, index) + # pair_act = self.add(pair_act, self.pair_transition(pair_act, index)) + # print("debug pair_act 287", pair_act) + return pair_act + + +class MultimerSingleTemplateEmbedding(nn.Cell): + '''multimer single template embedding''' + + def __init__(self, config, mixed_precision, device_num): + super(MultimerSingleTemplateEmbedding, self).__init__() + self.is_training = config.is_training + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_bins = self.config.dgram_features.num_bins + self.min_bin = self.config.dgram_features.min_bin + self.max_bin = self.config.dgram_features.max_bin + + self.num_channels = 
(self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + self.template_dgram_temp_dense = nn.Dense(39, self.num_channels, + weight_init=lecun_init(39, initializer_name='relu')) + self.template_mask_2d_temp_dense = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.aatype_temp_0 = nn.Dense(22, self.num_channels, + weight_init=lecun_init(22, initializer_name='relu')) + self.aatype_temp_1 = nn.Dense(22, self.num_channels, + weight_init=lecun_init(22, initializer_name='relu')) + self.unit_vector_0 = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.unit_vector_1 = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.unit_vector_2 = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.backbone_mask_2d_dense = nn.Dense(1, self.num_channels, + weight_init=lecun_init(1, initializer_name='relu')) + self.embedding2d = nn.Dense(128, self.num_channels, + weight_init=lecun_init(128, initializer_name='relu')) + template_layers = nn.CellList() + for _ in range(self.config.template_pair_stack.num_block): + # print("debug MultimerSingleTemplateEmbedding round", _) + template_pair_stack_block = MultimerTemplatePairStack(config, device_num) + if self.is_training: + template_pair_stack_block.recompute() + template_layers.append(template_pair_stack_block) + self.template_pair_stack = template_layers + + self.one_hot = nn.OneHot(depth=22, axis=-1) + self.n, self.ca, self.c = [atom_order[a] for a in ('N', 'CA', 'C')] + + layer_norm_dim = 64 + self.query_embedding_norm = nn.LayerNorm([128,], epsilon=1e-5) + self.output_layer_norm = nn.LayerNorm([layer_norm_dim,], epsilon=1e-5) + self.num_block = self.config.template_pair_stack.num_block + self.batch_block = 4 + + self.squeeze = P.Squeeze() + self.gather2 = P.Gather().shard(((1,1),())) + self.gather3 = P.Gather().shard(((1,1,1),())) + self.gather4 = P.Gather().shard(((1,1,1,1),())) + self.expand = P.ExpandDims().shard(((device_num,),)) + # self.dgram_from_positions_cell = DgramFromPositionsCell(self.num_bins, self.min_bin, self.max_bin, self._type, device_num) + + + def construct(self, pair_activations, template_aatype, + template_all_atom_positions, template_all_atom_mask, + padding_mask_2d, multichain_mask_2d): + '''construct''' + # print("debug MultimerSingleTemplateEmbedding", pair_activations, template_aatype, template_all_atom_positions, template_all_atom_mask, padding_mask_2d, multichain_mask_2d) + pair_activations = self.query_embedding_norm(pair_activations) + + # num_res, _, query_num_channels = pair_activations.shape + + # scan_init = mnp.zeros((num_res, num_res, self.num_channels), dtype=self._type) + scan_init = None + slice_act = None + # print("scan_init] ",scan_init) + # slice_act = None + for i in range(self.batch_block): + single_template_aatype = self.squeeze(self.gather2(template_aatype, Tensor(i), 0)) + single_template_all_atom_masks = self.squeeze(self.gather3(template_all_atom_mask, Tensor(i), 0)) + single_template_all_positions = self.squeeze(self.gather4(template_all_atom_positions, Tensor(i), 0)) + + template_pseudo_beta, template_pseudo_beta_mask = pseudo_beta_fn(single_template_aatype, + single_template_all_positions, + single_template_all_atom_masks) + # single_template_pseudo_beta = self.squeeze(self.gather3(template_pseudo_beta, Tensor(i), 0))#P.Squeeze()(template_pseudo_beta[i,...]) + # single_template_pseudo_beta_mask = 
self.squeeze(self.gather2(template_pseudo_beta_mask, Tensor(i), 0))
+
+            template_mask_2d_temp = self.expand(template_pseudo_beta_mask, -1) * self.expand(template_pseudo_beta_mask, 0)
+            template_mask_2d_temp = template_mask_2d_temp * multichain_mask_2d
+
+            template_dgram_temp = dgram_from_positions(template_pseudo_beta, self.num_bins, self.min_bin,
+                                                       self.max_bin, self._type)
+            template_dgram_temp *= template_mask_2d_temp[..., None]
+
+            act_tmp = self.template_dgram_temp_dense(template_dgram_temp)
+            act_tmp += self.template_mask_2d_temp_dense(P.ExpandDims()(template_mask_2d_temp, -1))
+
+            aatype_temp = self.one_hot(single_template_aatype)
+            aatype_temp = P.Cast()(aatype_temp, self._type)
+            act_tmp += self.aatype_temp_0(P.ExpandDims()(aatype_temp, 0))
+            act_tmp += self.aatype_temp_1(P.ExpandDims()(aatype_temp, 1))
+
+            backbone_mask = (single_template_all_atom_masks[:, self.n] *
+                             single_template_all_atom_masks[:, self.ca] *
+                             single_template_all_atom_masks[:, self.c])
+
+            unit_vector = multimer_rigids_get_unit_vector(single_template_all_positions[:, self.n],
+                                                          single_template_all_positions[:, self.ca],
+                                                          single_template_all_positions[:, self.c])
+
+            backbone_mask_2d = P.ExpandDims()(backbone_mask, -1) * P.ExpandDims()(backbone_mask, 0)
+            backbone_mask_2d *= multichain_mask_2d
+
+            # Zero the diagonal: intra-residue unit vectors carry no signal.
+            diagonal_mask = 1 - mnp.eye(multichain_mask_2d.shape[0])
+            unit_vector = (P.ExpandDims()(backbone_mask_2d * diagonal_mask * unit_vector[0], -1),
+                           P.ExpandDims()(backbone_mask_2d * diagonal_mask * unit_vector[1], -1),
+                           P.ExpandDims()(backbone_mask_2d * diagonal_mask * unit_vector[2], -1))
+
+            act_tmp += self.unit_vector_0(unit_vector[0])
+            act_tmp += self.unit_vector_1(unit_vector[1])
+            act_tmp += self.unit_vector_2(unit_vector[2])
+            act_tmp += self.backbone_mask_2d_dense(P.ExpandDims()(backbone_mask_2d, -1))
+            act_tmp += self.embedding2d(pair_activations)
+            if i > 0:
+                # Serialize the per-template passes so their activations do not peak together.
+                act_tmp = F.depend(act_tmp, slice_act)
+            for j in range(self.num_block):
+                act_tmp = self.template_pair_stack[j](act_tmp, padding_mask_2d)
+            slice_act = self.output_layer_norm(act_tmp)
+            if scan_init is None:
+                scan_init = slice_act
+            else:
+                scan_init += slice_act
+
+        return scan_init
+
+
+class MultimerTemplateEmbedding(nn.Cell):
+    '''multimer template embedding'''
+    # @lazy_inline
+    def __init__(self, config, device_num, mixed_precision=True):
+        super(MultimerTemplateEmbedding, self).__init__()
+        self.config = config.template
+        if mixed_precision:
+            self._type = mstype.float16
+        else:
+            self._type = mstype.float32
+        self.num_channels = self.config.template_pair_stack.triangle_attention_ending_node.value_dim
+        self.template_embedder = MultimerSingleTemplateEmbedding(config, mixed_precision, device_num)
+        self.relu = nn.ReLU()
+        self.output_linear = nn.Dense(self.num_channels, config.pair_channel,
+                                      weight_init=lecun_init(self.num_channels, initializer_name='relu'))
+
+    def construct(self, pair_activations, template_aatype, template_all_atom_mask, template_all_atom_positions,
+                  padding_mask_2d, multichain_mask_2d):
+        '''construct'''
+        num_templates = template_aatype.shape[0]
+        embedding = self.template_embedder(pair_activations, template_aatype,
+                                           template_all_atom_positions,
+                                           template_all_atom_mask,
+                                           padding_mask_2d,
+                                           multichain_mask_2d)
+        embedding = embedding / num_templates
+        embedding = self.relu(embedding)
+        output = self.output_linear(embedding)
+        return output
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Grasp/module/template_embedding_new.py b/MindSPONGE/applications/research/Grasp/module/template_embedding_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..74f54117382e5e9b2cf7ee6c5d60eeae1fd98c4c
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/module/template_embedding_new.py
@@ -0,0 +1,477 @@
+# Copyright 2023 @ Shenzhen Bay Laboratory &
+#                  Peking University &
+#                  Huawei Technologies Co., Ltd
+#
+# This code is a part of MindSPONGE:
+# MindSpore Simulation Package tOwards Next Generation molecular modelling.
+#
+# MindSPONGE is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +'''TEMPLATE''' +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindsponge.cell.initializer import lecun_init +from mindsponge.common.utils import dgram_from_positions, _memory_reduce +from mindsponge.common.geometry import make_transform_from_reference, quat_affine, invert_point +from mindsponge.common.residue_constants import atom_order +from mindsponge.cell import Attention, TriangleAttention, Transition, TriangleMultiplication + + +class TemplatePairStack(nn.Cell): + '''template pair stack''' + + def __init__(self, config): + super(TemplatePairStack, self).__init__() + self.config = config.template.template_pair_stack + self.num_block = self.config.num_block + batch_size = 0 + self.slice = config.slice.template_pair_stack + start_node_cfg = self.config.triangle_attention_starting_node + self.triangle_attention_starting_node = \ + TriangleAttention(start_node_cfg.orientation, + start_node_cfg.num_head, + start_node_cfg.key_dim, + start_node_cfg.gating, + 64, + batch_size, + self.slice.triangle_attention_starting_node) + end_node_cfg = self.config.triangle_attention_ending_node + self.triangle_attention_ending_node = \ + TriangleAttention(end_node_cfg.orientation, + end_node_cfg.num_head, + end_node_cfg.key_dim, + end_node_cfg.gating, + 64, + batch_size, + self.slice.triangle_attention_ending_node) + + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, + 64, + batch_size, + self.slice.pair_transition) + + mul_outgoing_cfg = self.config.triangle_multiplication_outgoing + self.triangle_multiplication_outgoing = \ + TriangleMultiplication(mul_outgoing_cfg.num_intermediate_channel, + mul_outgoing_cfg.equation, + layer_norm_dim=64, + batch_size=batch_size) + mul_incoming_cfg = self.config.triangle_multiplication_incoming + self.triangle_multiplication_incoming = \ + TriangleMultiplication(mul_incoming_cfg.num_intermediate_channel, + mul_incoming_cfg.equation, + layer_norm_dim=64, + batch_size=batch_size) + + def construct(self, pair_act, pair_mask, index=None): + "construct" + if not self.num_block: + return pair_act + + pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index) + pair_act = pair_act + self.pair_transition(pair_act, index) + return pair_act + + +class SingleTemplateEmbedding(nn.Cell): + '''single template embedding''' + + def __init__(self, config, is_training, mixed_precision): + super(SingleTemplateEmbedding, self).__init__() + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_bins = self.config.dgram_features.num_bins + self.min_bin = self.config.dgram_features.min_bin + self.max_bin = self.config.dgram_features.max_bin + + self.num_channels = self.config.template_pair_stack.triangle_attention_ending_node.value_dim + self.embedding2d = nn.Dense(88, self.num_channels, + weight_init=lecun_init(88, initializer_name='relu')) + + template_layers = nn.CellList() + for _ in range(self.config.template_pair_stack.num_block): + template_pair_stack_block = TemplatePairStack(config) + if 
is_training: + template_pair_stack_block.recompute() + template_layers.append(template_pair_stack_block) + self.template_pair_stack = template_layers + + self.one_hot = nn.OneHot(depth=22, axis=-1) + self.n, self.ca, self.c = [atom_order[a] for a in ('N', 'CA', 'C')] + + self.use_template_unit_vector = self.config.use_template_unit_vector + layer_norm_dim = 64 + self.output_layer_norm = nn.LayerNorm([layer_norm_dim,], epsilon=1e-5) + self.num_block = self.config.template_pair_stack.num_block + self.batch_block = 4 + + def construct(self, mask_2d, template_aatype, template_all_atom_masks, + template_all_atom_positions, template_pseudo_beta_mask, + template_pseudo_beta): + '''construct''' + num_res = template_aatype[0, ...].shape[0] + template_mask_2d_temp = P.ExpandDims()(template_pseudo_beta_mask, -1) * \ + P.ExpandDims()(template_pseudo_beta_mask, 1) + template_dgram_temp = dgram_from_positions(template_pseudo_beta, + self.num_bins, self.min_bin, + self.max_bin, self._type) + + to_concat_temp = (template_dgram_temp, P.ExpandDims()(template_mask_2d_temp, -1)) + aatype_temp = self.one_hot(template_aatype) + aatype_temp = P.Cast()(aatype_temp, self._type) + to_concat_temp = to_concat_temp + (P.Tile()(P.ExpandDims()(aatype_temp, 1), + (1, num_res, 1, 1)), + P.Tile()(P.ExpandDims()(aatype_temp, 2), + (1, 1, num_res, 1))) + + rot_temp, trans_temp \ + = make_transform_from_reference(template_all_atom_positions[:, :, self.n], + template_all_atom_positions[:, :, self.ca], + template_all_atom_positions[:, :, self.c]) + + _, rotation_tmp, translation_tmp = quat_affine(None, trans_temp, rot_temp) + points_tmp = [P.ExpandDims()(translation_tmp[0], -2), + P.ExpandDims()(translation_tmp[1], -2), + P.ExpandDims()(translation_tmp[2], -2)] + affine_vec_tmp = invert_point(points_tmp, rotation_tmp, translation_tmp, extra_dims=1) + inv_distance_scalar_tmp = P.Rsqrt()(1e-6 + P.Square()(affine_vec_tmp[0]) + \ + P.Square()(affine_vec_tmp[1]) + + P.Square()(affine_vec_tmp[2])) + template_mask_tmp = (template_all_atom_masks[:, :, self.n] * + template_all_atom_masks[:, :, self.ca] * + template_all_atom_masks[:, :, self.c]) + template_mask_2d_tmp = P.ExpandDims()(template_mask_tmp, -1) * \ + P.ExpandDims()(template_mask_tmp, 1) + + inv_distance_scalar_tmp = inv_distance_scalar_tmp * template_mask_2d_tmp + unit_vector_tmp = (P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[0], -1), + P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[1], -1), + P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[2], -1)) + + if not self.use_template_unit_vector: + unit_vector_tmp = (P.ZerosLike()(unit_vector_tmp[0]), + P.ZerosLike()(unit_vector_tmp[1]), + P.ZerosLike()(unit_vector_tmp[2])) + to_concat_temp = to_concat_temp + unit_vector_tmp + \ + (P.ExpandDims()(template_mask_2d_tmp, -1),) + act_tmp = P.Concat(-1)(to_concat_temp) + + act_tmp = act_tmp * P.ExpandDims()(template_mask_2d_tmp, -1) + act_tmp = self.embedding2d(act_tmp) + + act_tmp = P.Split(0, self.batch_block)(act_tmp) + act = () + for i in range(self.batch_block): + act = act + (P.Squeeze()(act_tmp[i]),) + + output = [] + slice_act = None + for i in range(self.batch_block): + act_batch = act[i] + if i > 0: + act_batch = F.depend(act_batch, slice_act) + for j in range(self.num_block): + act_batch = self.template_pair_stack[j](act_batch, mask_2d) + slice_act = P.Reshape()(act_batch, ((1,) + P.Shape()(act_batch))) + output.append(slice_act) + + act_tmp_loop = P.Concat()(output) + act_tmp = self.output_layer_norm(act_tmp_loop) + return act_tmp + + +class 
TemplateEmbedding(nn.Cell): + '''template embedding''' + + def __init__(self, config, is_training, mixed_precision=True): + super(TemplateEmbedding, self).__init__() + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_channels = self.config.template_pair_stack.triangle_attention_ending_node.value_dim + self.template_embedder = SingleTemplateEmbedding(config, is_training, mixed_precision) + self.template_pointwise_attention = Attention(self.config.attention.num_head, + self.config.attention.key_dim, + self.config.attention.gating, + q_data_dim=128, m_data_dim=64, + output_dim=128, batch_size=None) + self.slice_num = config.slice.template_embedding + + def compute(self, flat_query, flat_templates, input_mask): + embedding = self.template_pointwise_attention(flat_query, flat_templates, + input_mask, index=None, + nonbatched_bias=None) + return embedding + + def construct(self, query_embedding, template_aatype, template_all_atom_masks, + template_all_atom_positions, template_mask, template_pseudo_beta_mask, + template_pseudo_beta, mask_2d): + '''construct''' + num_templates = template_mask.shape[0] + num_channels = self.num_channels + num_res = query_embedding.shape[0] + query_num_channels = query_embedding.shape[-1] + mask_2d = F.depend(mask_2d, query_embedding) + template_pair_representation = self.template_embedder(mask_2d, template_aatype, + template_all_atom_masks, + template_all_atom_positions, + template_pseudo_beta_mask, + template_pseudo_beta) + flat_query = P.Reshape()(query_embedding, (num_res * num_res, 1, query_num_channels)) + flat_templates = P.Reshape()( + P.Transpose()(template_pair_representation, (1, 2, 0, 3)), + (num_res * num_res, num_templates, num_channels)) + template_mask_bias = P.ExpandDims()(P.ExpandDims()(P.ExpandDims()(template_mask, + 0), 1), 2) - 1.0 + input_mask = 1e4 * template_mask_bias + batched_inputs = (flat_query, flat_templates) + nonbatched_inputs = (input_mask,) + embedding = _memory_reduce(self.compute, batched_inputs, nonbatched_inputs, self.slice_num) + embedding = P.Reshape()(embedding, (num_res, num_res, query_num_channels)) + + embedding = embedding * (P.ReduceSum()(template_mask) > 0.) 
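`_memory_reduce` caps peak activation memory by running `compute` over chunks of the flattened pair dimension instead of all `num_res * num_res` rows at once. A simplified NumPy stand-in; the chunking details of the real `mindsponge.common.utils._memory_reduce` are an assumption here:

```python
import numpy as np

def memory_reduce_ref(compute, batched_inputs, nonbatched_inputs, slice_num):
    """Apply compute chunk-by-chunk along axis 0 and stitch the results back."""
    if not slice_num:                         # 0/None: a single full-size pass
        return compute(*batched_inputs, *nonbatched_inputs)
    chunks = zip(*(np.array_split(x, slice_num, axis=0) for x in batched_inputs))
    return np.concatenate([compute(*c, *nonbatched_inputs) for c in chunks], axis=0)
```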
+ return embedding + + +class SingleTemplateEmbeddingAverage(nn.Cell): + '''single template embedding''' + + def __init__(self, config, is_training, mixed_precision): + super(SingleTemplateEmbeddingAverage, self).__init__() + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_bins = self.config.dgram_features.num_bins + self.min_bin = self.config.dgram_features.min_bin + self.max_bin = self.config.dgram_features.max_bin + + self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + self.embedding2d = nn.Dense(88, self.num_channels, + weight_init=lecun_init(88, initializer_name='relu')) + + template_layers = nn.CellList() + for _ in range(self.config.template_pair_stack.num_block): + template_pair_stack_block = TemplatePairStack(config) + if is_training: + template_pair_stack_block.recompute() + template_layers.append(template_pair_stack_block) + self.template_pair_stack = template_layers + + self.one_hot = nn.OneHot(depth=22, axis=-1) + self.n, self.ca, self.c = [atom_order[a] for a in ('N', 'CA', 'C')] + + self.use_template_unit_vector = self.config.use_template_unit_vector + layer_norm_dim = 64 + self.output_layer_norm = nn.LayerNorm([layer_norm_dim,], epsilon=1e-5) + self.num_block = self.config.template_pair_stack.num_block + self.batch_block = 4 + + self.template_pointwise_attention = Attention(self.config.attention.num_head, + self.config.attention.key_dim, + self.config.attention.gating, + q_data_dim=128, m_data_dim=64, + output_dim=128, batch_size=None) + self.slice_num = config.slice.template_embedding + + def construct(self, mask_2d, template_aatype, template_all_atom_masks, template_all_atom_positions, + template_pseudo_beta_mask, template_pseudo_beta, query_embedding, template_mask): + '''construct''' + + num_channels = self.num_channels + num_res = query_embedding.shape[0] + query_num_channels = query_embedding.shape[-1] + template_aatype_batch = P.Split(0, self.batch_block)(template_aatype) + template_all_atom_masks_batch = P.Split(0, self.batch_block)(template_all_atom_masks) + template_all_atom_positions_batch = P.Split(0, self.batch_block)(template_all_atom_positions) + template_pseudo_beta_mask_batch = P.Split(0, self.batch_block)(template_pseudo_beta_mask) + template_pseudo_beta_batch = P.Split(0, self.batch_block)(template_pseudo_beta) + template_mask_batch = P.Split(0, self.batch_block)(template_mask) + + + embedding_all = 0 + + for i in range(self.batch_block): + template_aatype = template_aatype_batch[i] + template_all_atom_masks = template_all_atom_masks_batch[i] + template_all_atom_positions = template_all_atom_positions_batch[i] + template_pseudo_beta_mask = template_pseudo_beta_mask_batch[i] + template_pseudo_beta = template_pseudo_beta_batch[i] + template_mask = template_mask_batch[i] + + template_aatype = F.depend(template_aatype, embedding_all) + template_all_atom_masks = F.depend(template_all_atom_masks, embedding_all) + template_all_atom_positions = F.depend(template_all_atom_positions, embedding_all) + template_pseudo_beta_mask = F.depend(template_pseudo_beta_mask, embedding_all) + template_pseudo_beta = F.depend(template_pseudo_beta, embedding_all) + template_mask = F.depend(template_mask, embedding_all) + + num_res = template_aatype[0, ...].shape[0] + template_mask_2d_temp = P.ExpandDims()(template_pseudo_beta_mask, -1) * \ + P.ExpandDims()(template_pseudo_beta_mask, 1) + template_dgram_temp = dgram_from_positions( + template_pseudo_beta, 
+ self.num_bins, + self.min_bin, + self.max_bin, + self._type + ) + + to_concat_temp = (template_dgram_temp, P.ExpandDims()(template_mask_2d_temp, -1)) + aatype_temp = self.one_hot(template_aatype) + aatype_temp = P.Cast()(aatype_temp, self._type) + to_concat_temp = to_concat_temp + (P.Tile()(P.ExpandDims()(aatype_temp, 1), (1, num_res, 1, 1)), + P.Tile()(P.ExpandDims()(aatype_temp, 2), (1, 1, num_res, 1))) + + rot_temp, trans_temp = make_transform_from_reference( + template_all_atom_positions[:, :, self.n], + template_all_atom_positions[:, :, self.ca], + template_all_atom_positions[:, :, self.c]) + + _, rotation_tmp, translation_tmp = quat_affine(None, trans_temp, rot_temp) + points_tmp = [ + P.ExpandDims()(translation_tmp[0], -2), + P.ExpandDims()(translation_tmp[1], -2), + P.ExpandDims()(translation_tmp[2], -2) + ] + affine_vec_tmp = invert_point(points_tmp, rotation_tmp, translation_tmp, extra_dims=1) + inv_distance_scalar_tmp = P.Rsqrt()(1e-6 + P.Square()(affine_vec_tmp[0]) + P.Square()(affine_vec_tmp[1]) + \ + P.Square()(affine_vec_tmp[2])) + template_mask_tmp = ( + template_all_atom_masks[:, :, self.n] * + template_all_atom_masks[:, :, self.ca] * + template_all_atom_masks[:, :, self.c]) + template_mask_2d_tmp = P.ExpandDims()(template_mask_tmp, -1) * P.ExpandDims()(template_mask_tmp, 1) + + inv_distance_scalar_tmp = inv_distance_scalar_tmp * template_mask_2d_tmp + unit_vector_tmp = ( + P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[0], -1), + P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[1], -1), + P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[2], -1) + ) + + if not self.use_template_unit_vector: + unit_vector_tmp = ( + P.ZerosLike()(unit_vector_tmp[0]), + P.ZerosLike()(unit_vector_tmp[1]), + P.ZerosLike()(unit_vector_tmp[2]) + ) + to_concat_temp = to_concat_temp + unit_vector_tmp + (P.ExpandDims()(template_mask_2d_tmp, -1),) + act_slice = P.Concat(-1)(to_concat_temp) + act_slice = act_slice * P.ExpandDims()(template_mask_2d_tmp, -1) + + act_slice = P.Squeeze()(act_slice) + if i > 0: + act_slice = F.depend(act_slice, embedding_all) + + act_slice = self.embedding2d(act_slice) + for j in range(self.num_block): + act_slice = self.template_pair_stack[j](act_slice, mask_2d) + + act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice))) + act_slice = self.output_layer_norm(act_slice) + + query_embedding = F.depend(query_embedding, act_slice) + flat_query = P.Reshape()(query_embedding, (num_res * num_res, 1, query_num_channels)) + flat_templates = P.Reshape()( + P.Transpose()(act_slice, (1, 2, 0, 3)), + (num_res * num_res, 1, num_channels)) + + template_mask_bias = P.ExpandDims()(P.ExpandDims()(P.ExpandDims()(template_mask, 0), 1), 2) - 1.0 + input_mask = 1e4 * template_mask_bias + + if self.slice_num: + slice_shape = (self.slice_num, -1) + flat_query_shape = P.Shape()(flat_query) + flat_query = P.Reshape()(flat_query, slice_shape + flat_query_shape[1:]) + flat_templates_shape = P.Shape()(flat_templates) + flat_templates = P.Reshape()(flat_templates, slice_shape + flat_templates_shape[1:]) + slice_idx = 0 + embedding_tuple = () + embedding_slice = None + while slice_idx < self.slice_num: + flat_query_slice_ = flat_query[slice_idx] + flat_templates_slice_ = flat_templates[slice_idx] + if slice_idx > 0: + flat_query_slice_ = F.depend(flat_query_slice_, embedding_slice) + flat_templates_slice_ = F.depend(flat_templates_slice_, embedding_slice) + embedding_slice = self.template_pointwise_attention(flat_query_slice_, flat_templates_slice_, + input_mask, 
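+                        # Pointwise attention lets every flattened (i, j) pair
+                        # position query the template channel independently; looping
+                        # over slice_num chunks (the while loop above) bounds peak
+                        # activation memory for long sequences.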
index=None, nonbatched_bias=None) + embedding_slice = P.Reshape()(embedding_slice, ((1,) + P.Shape()(embedding_slice))) + embedding_tuple = embedding_tuple + (embedding_slice,) + slice_idx += 1 + embedding = P.Concat()(embedding_tuple) + + embedding = P.Reshape()(embedding, (num_res, num_res, query_num_channels)) + + embedding = embedding * (P.ReduceSum()(template_mask) > 0.) + embedding_all += embedding * 0.25 + else: + embedding = self.template_pointwise_attention( + flat_query, + flat_templates, + input_mask, + index=None, + nonbatched_bias=None + ) + embedding = P.Reshape()(embedding, (num_res, num_res, query_num_channels)) + embedding_all += embedding * 0.25 + + return embedding_all + + +class TemplateEmbeddingAverage(nn.Cell): + '''template embedding''' + + def __init__(self, config, is_training, mixed_precision=True): + super(TemplateEmbeddingAverage, self).__init__() + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + self.template_embedder = SingleTemplateEmbeddingAverage(config, is_training, mixed_precision) + + + def construct(self, query_embedding, template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, mask_2d): + '''construct''' + mask_2d = F.depend(mask_2d, query_embedding) + embedding = self.template_embedder( + mask_2d, + template_aatype, + template_all_atom_masks, + template_all_atom_positions, + template_pseudo_beta_mask, + template_pseudo_beta, + query_embedding, + template_mask + ) + + embedding = embedding * (P.ReduceSum()(template_mask) > 0.) + return embedding diff --git a/MindSPONGE/applications/research/Grasp/restraint_sample.py b/MindSPONGE/applications/research/Grasp/restraint_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..77055526caef77417e14249c3bd804ea8e08a615 --- /dev/null +++ b/MindSPONGE/applications/research/Grasp/restraint_sample.py @@ -0,0 +1,275 @@ +import numpy as np +import traceback +# from scipy.stats import lognorm +BINS = np.arange(4, 33, 1) + +def normalize_number_in_bins(dist, bins): + upper_edges = np.array(list(bins) + [np.inf]) + lower_edges = np.array([0] + list(bins)) + num_in_bins = ((dist.flatten()<=upper_edges[..., None])*(dist.flatten()>lower_edges[..., None])).sum(-1) + dist_which_bin = (dist[..., None]>bins).sum(-1) + p_norm = (1/(num_in_bins+1e-8))[dist_which_bin] + return p_norm + + +# def sample_dist_1(dist, mask, num, thre, fdr, bins=BINS): +# d = dist * mask +# idx = np.where(d>1) # remove masked sites +# d = d[idx] +# true_num = (d<=thre).sum() +# total_max = np.ceil(true_num/(1-fdr)) +# num = int(min(total_max, num, len(d))) +# p_norm = normalize_number_in_bins(d, bins) +# p = fdr * (d>thre)/(bins>=thre).sum() + (1-fdr)*(d<=thre)/(bins<=thre).sum() +# p *= p_norm +# p /= p.sum() +# chosen_idx = np.random.choice(np.arange(p.size), size=num, p=p.ravel(), replace=False) + +# idx = np.transpose(idx)[chosen_idx] +# idx = (idx[:,0], idx[:, 1]) + +# return idx + +def sample_dist(dist, mask, num, thre, fdr, bins=BINS): + d = dist * mask + idx = np.where(d>1) # remove masked sites + d = d[idx] + + true_num = (d<=thre).sum() + false_num = d.size - true_num + + true_num = int(min(true_num, round(num*(1-fdr)))) + false_num = int(min(false_num, round(num*fdr))) + + p_norm = normalize_number_in_bins(d, bins) + + chosen_idx = [] + # sample true + if true_num>0: + p 
= np.ones(d.shape) * (d<=thre) + p *= p_norm + p /= p.sum() + idx_temp = np.random.choice(np.arange(p.size), size=true_num, p=p.ravel(), replace=False) + chosen_idx.extend(idx_temp) + + # sample false + if false_num>0: + p = np.ones(d.shape) * (d>thre) + p *= p_norm + p /= p.sum() + idx_temp = np.random.choice(np.arange(p.size), size=false_num, p=p.ravel(), replace=False) + chosen_idx.extend(idx_temp) + + idx = np.transpose(idx)[chosen_idx] + idx = (idx[:,0], idx[:, 1]) + + return idx + +def get_crop_index(asym_id, residue_index, chain_index): + unique_chain_index = np.unique(chain_index) + unique_chain_index.sort() + crop_index = [] + seq_len = 0 + for i in unique_chain_index: + # print(i, seq_len) + res_idx_i = residue_index[asym_id == (i+1)] + res_idx_i += seq_len + crop_index.extend(res_idx_i) + seq_len += (chain_index == i).sum() + # print(crop_index) + return crop_index + +def get_sample_num(start, reduce, end, k): + probs = np.concatenate((np.zeros(start), np.ones(reduce-start), np.exp(-(np.arange(end-reduce)/k))),axis=0) + probs = normalize(probs) + num = np.random.choice(np.arange(end), p=probs) + return num + +def normalize(probs): + probs[probs<0] = 0 + probs /= probs.sum() + return probs + +def generate_mask(dist, pseudo_beta_mask, asym_id, residue_index): + ''' add mask info''' + remote_residue_threshold = 6 + greater_residue_index_diff = (np.abs(residue_index[:, None] - residue_index[None]) > remote_residue_threshold).astype(np.float32) + pseudo_beta_mask_2d = (pseudo_beta_mask[:,None] * pseudo_beta_mask[None]) * (dist > 0.01) + upper_mask = np.triu(np.ones_like(pseudo_beta_mask_2d), 1) + mask_intra = (asym_id[:, None] == asym_id[None]).astype(np.float32) + mask_inter = (1.0 - mask_intra) * pseudo_beta_mask_2d + mask_interface = (dist < 8.0) * mask_inter + interface_dist = (dist+(1-mask_inter)*1e8).min(-1) + mask_interface = mask_interface.any(-1) + mask_inter *= upper_mask + mask_intra = mask_intra * greater_residue_index_diff * pseudo_beta_mask_2d * upper_mask + return mask_inter, mask_intra, mask_interface, interface_dist + +def sample_interface_by_asym_id(asym_id, mask, num): + num = min(num, mask.sum()) + mask_interface = np.zeros_like(mask) + if num > 0: + asym_id_same = asym_id[:, None] == asym_id[None] + num_interface_each_chain = (asym_id_same * mask[None]).sum(-1) + probs = mask / (num_interface_each_chain + 1e-8) + probs = normalize(probs) + idx = np.random.choice(np.arange(len(asym_id)), num, replace=False, p=probs) + mask_interface[idx] = 1.0 + return mask_interface + +def single_bin(dist, fdr, bins): + r = np.eye(len(bins)+1)[(dist[..., None] > bins).sum(-1).astype(np.int32)] + r = r*(1-fdr) + (1-r)*fdr/((1-r).sum(-1, keepdims=True)) + return r + +def uniform_cutoff(thre, fdr, bins): + r = np.ones((len(bins)+1)) + num_lower = (thre >= bins).sum() + r[:num_lower] = r[:num_lower]/r[:num_lower].sum() * (1-fdr) + r[num_lower:] = r[num_lower:]/r[num_lower:].sum() * fdr + return r + +def print_rpr(dist, mask, sbr, thre): + if mask.sum()>0: + d = dist[mask>0.5] + print(f'Total:{d.size}, FDR: {(d>thre).sum()/d.size}, Thre: {thre}, Dist: {d}') + # if not (sbr[mask>0.5].sum(-1)==1).all(): + # print(sbr[mask>0.5].sum(-1)) + # assert (sbr[mask>0.5].sum(-1)==1).all() + else: + print('No restraint') + + +def generate_interface_and_restraints(d, num_inter=0, num_intra=0, num_interface=0, thre=8, fdr=0.05, + mixed_precision=True, training=True, + seed=None, fix_afm=True, bins = BINS): + if seed is not None: + np.random.seed(seed) + + # assert 'pseudo_beta' in d + # assert 
'pseudo_beta_mask' in d
+    # assert 'asym_id' in d
+    # assert 'residue_index' in d
+    # assert 'chain_index' in d
+
+    if training:
+        asym_id = d['asym_id'][0]
+        residue_index = d['residue_index'][0]
+        chain_index = d['chain_index']
+        crop_index = get_crop_index(asym_id, residue_index, chain_index)
+        seqlen = len(asym_id)  # fixed crop size, e.g. 384
+
+        # check crop index
+        aatype_pdb = d['aatype_per_chain'][crop_index]
+        aatype_pdb = np.pad(aatype_pdb, ((0, seqlen - aatype_pdb.shape[0]),))
+        aatype = d['aatype'][0]
+        delta = (np.abs(aatype - aatype_pdb) * (aatype_pdb < 20)).sum()
+        if delta > 0:
+            print('error! crop index is wrong!')
+            print(aatype)
+            print(aatype_pdb)
+            raise ValueError
+
+        pseudo_beta = d["pseudo_beta"][crop_index]
+        pseudo_beta_mask = d['pseudo_beta_mask'][crop_index]
+        # pad to fixed length
+        pseudo_beta = np.pad(pseudo_beta, ((0, seqlen - pseudo_beta.shape[0]), (0, 0)))
+        pseudo_beta_mask = np.pad(pseudo_beta_mask, ((0, seqlen - pseudo_beta_mask.shape[0]),))
+        dist = np.sqrt((np.square(pseudo_beta[None] - pseudo_beta[:, None])).sum(-1) + 1e-8)
+
+    else:
+        asym_id = d['asym_id']
+        seqlen = len(asym_id)
+        pseudo_beta = d['pseudo_beta'] if 'pseudo_beta' in d else np.zeros((seqlen, 3))
+        if 'pseudo_beta_mask' in d:
+            pseudo_beta_mask = d['pseudo_beta_mask']
+        elif 'mask_2d' in d:
+            pseudo_beta_mask = (d['mask_2d'].sum(0) > 0.5).astype(d['mask_2d'].dtype)
+        else:
+            pseudo_beta_mask = np.ones_like(asym_id)
+        dist = np.sqrt((np.square(pseudo_beta[None] - pseudo_beta[:, None])).sum(-1) + 1e-8)
+        dist = d['dist'] if 'dist' in d else dist
+        residue_index = d['residue_index'] if 'residue_index' in d else np.arange(seqlen)
+
+    sbr = np.zeros((seqlen, seqlen, len(bins) + 1))
+    sbr_mask = np.zeros((seqlen, seqlen))
+    mask_interface = np.zeros(seqlen)
+    try:
+        if training:
+            num_inter = 0
+            num_intra = 0
+            num_interface = 0
+
+            sample_ratio = 0.5
+            if np.random.rand() < sample_ratio:
+                num_inter = get_sample_num(start=1, reduce=20, end=40, k=4)
+
+            if np.random.rand() < sample_ratio:
+                num_intra = get_sample_num(start=1, reduce=80, end=160, k=16)
+
+            if np.random.rand() < sample_ratio:
+                num_interface = get_sample_num(start=1, reduce=40, end=80, k=8)
+
+            if fix_afm and num_inter + num_intra + num_interface == 0:
+                num_inter = get_sample_num(start=1, reduce=20, end=40, k=4)
+                num_intra = get_sample_num(start=1, reduce=80, end=160, k=16)
+                num_interface = get_sample_num(start=1, reduce=40, end=80, k=8)
+
+        # Only one chain: no inter-chain or interface restraints can be sampled
+        if len(np.unique(asym_id)) == 1:
+            num_interface = 0
+            num_inter = 0
+
+        mask_inter, mask_intra, mask_interface, interface_dist = generate_mask(dist, pseudo_beta_mask, asym_id, residue_index)
+
+        if training:
+            single_bin_ratio = 0.5
+            if np.random.rand() < single_bin_ratio:
+                thre = 30
+                r = single_bin(dist, fdr=0.05, bins=bins)
+            else:
+                thre = np.random.randint(low=8, high=31)
+                r = uniform_cutoff(thre, fdr=fdr, bins=bins)
+                r = np.tile(r, (*dist.shape, 1))
+        else:
+            r = uniform_cutoff(thre, fdr=fdr, bins=bins)
+            r = np.tile(r, (*dist.shape, 1))
+
+        intra_pair = sample_dist(dist, mask_intra, num_intra, thre=thre, fdr=fdr, bins=bins)
+        inter_pair = sample_dist(dist, mask_inter, num_inter, thre=thre, fdr=fdr, bins=bins)
+        sbr[intra_pair] = r[intra_pair]
+        sbr[inter_pair] = r[inter_pair]
+        sbr += sbr.swapaxes(0, 1)
+
+        mask_interface = sample_interface_by_asym_id(asym_id, mask_interface, num_interface)
+
+        dtype = np.float32
+        if mixed_precision:
+            dtype = np.float16
+        sbr = sbr.astype(dtype)
+        sbr_mask = (sbr.sum(-1) > 0.5).astype(dtype)
+
+        mask_interface = mask_interface.astype(dtype)
+
+        # show info
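+        # Worked example of the encoding above (illustrative numbers): with
+        # bins = np.arange(4, 33) there are 30 distance bins, so sbr has shape
+        # (seqlen, seqlen, 30). uniform_cutoff(thre=8, fdr=0.05) spreads 0.95 of
+        # the probability mass over the 5 bins whose upper edge is <= 8 A (0.19
+        # each) and 0.05 over the remaining 25 bins (0.002 each); each sampled
+        # pair (i, j) stores this distribution in sbr[i, j].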
+        print('inter rpr: =======================================')
+        print_rpr(dist, mask_inter*sbr_mask, sbr, thre)
+        print('intra rpr: =======================================')
+        print_rpr(dist, mask_intra*sbr_mask, sbr, thre)
+        print('interface: =======================================')
+        print(f'Total: {int(mask_interface.sum())}, Dist: {interface_dist[mask_interface>0.5]}')
+
+    except Exception as e:
+        sbr = np.zeros((seqlen, seqlen, len(bins) + 1))
+        sbr_mask = np.zeros((seqlen, seqlen))
+        mask_interface = np.zeros(seqlen)
+        print('Error in sample restraints:', e)
+        traceback.print_exc()
+
+    return sbr, sbr_mask, mask_interface
+
diff --git a/MindSPONGE/applications/research/Grasp/utils_infer.py b/MindSPONGE/applications/research/Grasp/utils_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc6d33335b9a53c1934d7a138eda53c6b71a93ac
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/utils_infer.py
@@ -0,0 +1,1066 @@
+import os
+import re
+import gc
+import time
+import glob
+import pickle
+import datetime
+import numpy as np
+import pandas as pd
+import mindspore.context as context
+import mindspore.numpy as mnp
+import mindspore.common.dtype as mstype
+
+from model import MegaFold
+from model import compute_confidence, compute_ranking_score
+from common.protein import to_pdb, from_prediction, PDB_CHAIN_IDS
+from common.utils import trans_ckpt
+from data import MultimerFeature
+from restraint_sample import BINS
+
+from mindspore import Tensor
+from mindspore import load_checkpoint, nn, load_param_into_net
+from mindspore.communication import init, get_rank
+
+from mindsponge1.common.protein import from_pdb_string_all_chains
+from mindsponge1.data.data_transform import pseudo_beta_fn
+from mindsponge1.cell.amp import amp_convert
+from mindsponge1.common.config_load import load_config
+from mindsponge1.common import residue_constants
+# import mindspore.train.amp as amp
+# from mindspore.ops import operations as P
+# amp.AMP_AUTO_BLACK_LIST.clear()
+# amp.AMP_AUTO_BLACK_LIST.extend([P.LayerNorm, P.Softmax])
+# amp.AMP_AUTO_WHITE_LIST.extend([P.BiasAdd])
+SEED = 20230820
+
+
+# def get_seqs_from_fasta(fasta):
+#     with open(fasta, 'r') as f:
+#         cont = [i.strip() for i in f.readlines()]
+#     seqs = cont[1::2]
+#     return seqs
+
+def parse_fasta(fasta):
+    with open(fasta, 'r') as f:
+        cont = [i.strip() for i in f.readlines()]
+    seqdict = {}
+    desc = None
+    for line in cont:
+        if line.startswith('>'):
+            if desc is not None:
+                seqdict[desc] = seq
+            seq = ''
+            desc = line[1:].strip()
+        else:
+            seq += line
+    seqdict[desc] = seq
+    return seqdict
+
+def get_mapping_from_fasta(fasta):
+    mapping = parse_fasta(fasta)
+    mapping = {k.split('_')[-1][0]: v for k, v in mapping.items()}
+    return mapping
+    # with open(fasta, 'r') as f:
+    #     cont = [i.strip() for i in f.readlines()]
+    # mapping = {}
+    # for i in range(0, len(cont), 2):
+    #     k = cont[i].split('_')[-1][0]
+    #     v = cont[i+1]
+    #     mapping[k] = v
+    # print(mapping)
+    # return mapping
+
+def get_order_from_seqs(seqs):
+    seqdict = {}
+    for seq in seqs:
+        if seq not in seqdict:
+            seqdict[seq] = 1
+        else:
+            seqdict[seq] += 1
+
+    p = 0  # pointer into the grouped layout
+    for seq, k in seqdict.items():
+        ls = []
+        for i in range(k):
+            ls.append(range(p, p+len(seq)))
+            p += len(seq)
+        seqdict[seq] = ls
+    ls = []
+    for seq in seqs:
+        ls.append(seqdict[seq][0])
+        seqdict[seq].pop(0)
+    return ls
+
+def reorder(x, slices, axis):
+    return np.concatenate([np.take(x, i, axis) for i in slices], axis=axis)
+
+def reorder_features(feats, seqs):
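+    # Example (illustrative): for seqs ['X', 'Y', 'X'] the raw features are laid
+    # out with identical chains grouped as X, X, Y; get_order_from_seqs returns
+    # index ranges that map every seqlen-sized axis back to the FASTA order X, Y, X.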
+    ord = get_order_from_seqs(seqs)
+    seqlen = feats['aatype'].shape[0]
+    for k, v in feats.items():
+        # print(k, v.shape)
+        for i, s in enumerate(v.shape):
+            if s == seqlen:
+                v = reorder(v, ord, i)
+        feats[k] = v
+
+
+def np_pad(array, seqlen, axis=None):
+    pad_width = []
+    if isinstance(axis, int):
+        axis = (axis,)
+    if axis is None:
+        axis = range(len(array.shape))
+    # print("=================array===================: ", array.shape)
+    for i, n in enumerate(array.shape):
+        if i in axis:
+            pad_width.append((0, seqlen - n))
+        else:
+            pad_width.append((0, 0))
+    return np.pad(array=array, pad_width=pad_width)
+
+def get_dist_from_protein(prot):
+    pseudo_beta, pseudo_beta_mask = pseudo_beta_fn(prot.aatype, prot.atom_positions, prot.atom_mask)
+    pred_dist = np.sqrt(((pseudo_beta[:, None] - pseudo_beta[None]) ** 2).sum(-1) + 1e-8)
+    pseudo_beta_mask_2d = pseudo_beta_mask[:, None] * pseudo_beta_mask[None]
+    return pred_dist, pseudo_beta_mask_2d
+
+def get_nbdist_avg_ca(prot, asym_id, break_thre=5.0):
+    """compute averaged neighbour CA distance for each residue"""
+    # atom_types = [
+    #     'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
+    #     'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
+    #     'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
+    #     'CZ3', 'NZ', 'OXT'
+    # ]
+    ca_idx = 1
+    ca_pos = prot.atom_positions[..., ca_idx, :]  # [nres, 3]
+    mask = prot.atom_mask[..., ca_idx]
+    nbdist = np.sqrt(((ca_pos[1:]-ca_pos[:-1])**2).sum(-1)+1e-8)
+    mask_nbdist = 4.0
+    # residues with a masked CA get a neutral neighbour distance on both sides
+    for i in np.where(1-mask)[0]:
+        print(i)
+        if i < len(nbdist):  # the last residue has no following neighbour
+            nbdist[i] = mask_nbdist
+        if i > 0:
+            nbdist[i-1] = mask_nbdist
+    nbdist_leftadd = np.concatenate([[nbdist[0]], nbdist])
+    nbdist_rightadd = np.concatenate([nbdist, [nbdist[-1]]])
+    is_chain_start = asym_id != np.concatenate(([-1], asym_id[:-1]))
+    is_chain_end = asym_id != np.concatenate((asym_id[1:], [100000]))
+    nbdist_left = np.where(is_chain_start, nbdist_rightadd, nbdist_leftadd)
+    nbdist_right = np.where(is_chain_end, nbdist_leftadd, nbdist_rightadd)
+    nbdist_avg = (nbdist_left+nbdist_right)/2
+
+    break_num = int((nbdist_left > break_thre).sum())
+    max_nb_dist = nbdist_left.max()
+
+    return nbdist_avg, break_num, max_nb_dist
+
+
+def dist_onehot(dist, bins):
+    x = (dist[..., None] > bins).sum(-1)
+    return np.eye(len(bins) + 1)[x]
+
+def get_range(x):
+    lowers = np.concatenate([[0], BINS])
+    uppers = np.concatenate([BINS, [np.inf]])
+    intervals = [(i, j) for i, j, k in zip(lowers, uppers, x) if k]
+    ys = []
+    last = None
+    for i, j in intervals:
+        if (last is not None) and (last == i):
+            ys[-1][-1] = j
+        else:
+            ys.append([i, j])
+        last = j
+    return ','.join([f'{i}-{j}' for i, j in ys])
+
+
+def compute_recall(satis, mask, conf):
+    if mask.sum() <= 0:
+        return None, None
+
+    recall = (satis*mask).sum()/(mask.sum()+1e-8)
+    recall_conf = (satis*mask*conf).sum()/((mask*conf).sum())
+    return recall, recall_conf
+
+def compute_rm_score(values, thres):
+    # later criteria dominate: each successive one weighs 100x more
+    score = 0
+    scale = 1
+    assert len(values) == len(thres), (values, thres)
+    for value, thre in zip(values, thres):
+        if (thre is not None):
+            if (value > thre):
+                score += (value-thre)*scale
+        scale *= 100
+    return score
+
+
+def generate_terminal_mask(asym_id, n):
+    is_end = asym_id != np.concatenate([asym_id[1:], [-1]])
+    is_start = asym_id != np.concatenate([[0], asym_id[:-1]])
+    end_idx = np.where(is_end)[0]
+    start_idx = np.where(is_start)[0]
+    term_idx = np.concatenate([end_idx, start_idx])
+    idx = np.arange(len(asym_id))
+    mask = (np.abs(idx[:, None] - term_idx[None]) >= n).all(axis=-1).astype(int)
+    mask = mask[None]  # only mask the other side which is different from the interface.
+    return mask
+
+
+def filter_restraints(restraints, restraints0, prot, nbdist_ca_thre=5, max_rm_ratio=0.2, viol_thre=5, mask_terminal_residues=0):
+    # restraints0: initial restraints.
+    # restraints: current restraints.
+
+    plddt = prot.b_factors.max(-1)
+    pred_dist, pseudo_beta_mask_2d = get_dist_from_protein(prot)
+    mask_intrachain = restraints['asym_id'][None] == restraints['asym_id'][:, None]
+    terminal_residue_mask = generate_terminal_mask(restraints['asym_id'], mask_terminal_residues)
+
+    d = pred_dist + mask_intrachain*1000 + (1-pseudo_beta_mask_2d) * 1000 + (1-terminal_residue_mask) * 1000
+    # dist_thre=10.0
+    # plddts_2d = (d<=dist_thre)*plddt[None]
+    # plddt_otherside = plddts_2d.max(axis=1)
+    sbr = restraints['sbr']
+    sbr_high = (sbr > (1 / sbr.shape[-1]))
+
+    not_high_bin = 1-sbr_high
+    upper_1d = np.concatenate([BINS, [100,]])
+    sbr_upper_thre = (upper_1d-1e6*not_high_bin).max(-1)
+    sbr_upper_viol_dist = (pred_dist-sbr_upper_thre)
+    sbr_max_viol_dist = (sbr_upper_viol_dist * restraints['sbr_mask']).max()
+    sbr_viol_num = ((sbr_upper_viol_dist * restraints['sbr_mask']) > 0).sum() / 2
+    interface_viol_dist = ((d.min(axis=-1)-8.0)*restraints['interface_mask'])
+    interface_max_viol_dist = interface_viol_dist.max()
+    interface_viol_num = (interface_viol_dist > 0).sum()
+    viol_num = sbr_viol_num + interface_viol_num
+    max_viol_dist = max(sbr_max_viol_dist, interface_max_viol_dist)
+    pred_dist_onehot = dist_onehot(pred_dist, BINS)
+    sbr_satis = (sbr_high * pred_dist_onehot).sum(-1) * pseudo_beta_mask_2d
+    nbdist_avg_ca, break_num, max_nb_dist = get_nbdist_avg_ca(prot, asym_id=restraints['asym_id'])
+    includ_mat = np.zeros_like(restraints['sbr_mask'])
+    includ_if = np.zeros_like(restraints['interface_mask'])
+
+    def resi(i, ds=None):
+        cid = PDB_CHAIN_IDS[int(restraints['asym_id'][i])-1]
+        rid = prot.residue_index[i]
+        y = f'{cid}{rid}/conf{plddt[i]:.2f}/nbdist_avg_ca{nbdist_avg_ca[i]:.2f}'
+        if ds is not None:
+            y += f'/dist_cb{ds[i]:.2f}'
+        return y
+
+    def print_pair(ps):
+        ps = [(i, j) for i, j in ps if i < j]
+        included_num = 0
+        satisfied_num = 0
+        nbdists = [max(nbdist_avg_ca[i], nbdist_avg_ca[j]) for i, j in ps]
+        viol_dists = [sbr_upper_viol_dist[i, j] for i, j in ps]
+        rm_scores = [compute_rm_score((viol_dist, nb_dist), (viol_thre, nbdist_ca_thre)) for nb_dist, viol_dist in zip(nbdists, viol_dists)]
+        rm_thre = np.quantile(rm_scores, 1-max_rm_ratio)
+        for (i, j), rm_score in zip(ps, rm_scores):
+            if sbr_satis[i, j] > 0:
+                satisfied_num += 1
+                satis_info = 'Satisfied!'
+            else:
+                satis_info = 'Violated! '
+            if rm_score <= rm_thre:
+                includ_mat[i, j] = includ_mat[j, i] = 1
+                included_num += 1
+                filter_info = 'Included!'
+            else:
+                filter_info = 'Excluded!'
+            print(f'{filter_info} {satis_info} {resi(i)}<==>{resi(j, pred_dist[i])}, range: {get_range(sbr_high[i,j])}, rm_score {rm_score}, rm_thre {rm_thre}')
+        print(f'>>>>> Total {len(ps)}: {included_num} included, {satisfied_num} satisfied')
+
+    # print interface info ==========================================================
+    if_num = int(restraints['interface_mask'].sum())
+    if if_num > 0:
+        print('interface restraints:')
+        included_num = 0
+        satisfied_num = 0
+        nbdists = [nbdist_avg_ca[i] for i in np.where(restraints['interface_mask'])[0]]
+        viol_dists = [d[i].min()-8.0 for i in np.where(restraints['interface_mask'])[0]]
+        rm_scores = [compute_rm_score((viol_dist, nb_dist), (viol_thre, nbdist_ca_thre)) for nb_dist, viol_dist in zip(nbdists, viol_dists)]
+        rm_thre = np.quantile(rm_scores, 1-max_rm_ratio)
+        for i, rm_score in zip(np.where(restraints['interface_mask'])[0], rm_scores):
+            # js = np.where((plddts_2d[i])>0)[0]
+            if d[i].min() <= 8.0:
+                satisfied_num += 1
+                satis_info = 'Satisfied!'
+            else:
+                satis_info = 'Violated! '
+
+            # if len(js)==0:
+            #     print(f'Excluded! {satis_info} {resi(i)}<==>{resi(np.argmin(ds), ds)}')
+            # else:
+            #     jmax = np.argmax(plddts_2d[i])
+
+            if (rm_score <= rm_thre):
+                includ_if[i] = 1
+                included_num += 1
+                filter_info = 'Included!'
+            else:
+                filter_info = 'Excluded!'
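+            # Scoring example (illustrative numbers): with viol_thre=5 and
+            # nbdist_ca_thre=5, a restraint violated by 7 A at a residue whose
+            # neighbour-CA distance is 6 A gets rm_score = (7-5)*1 + (6-5)*100 = 102;
+            # scores above the (1 - max_rm_ratio) quantile are excluded.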
+ print(f'{filter_info} {satis_info} {resi(i)} {d[i].min()}, rm_score{rm_score}, rm_thre{rm_thre}') + + print(f'>>>>> Total {if_num}, {included_num} included, {satisfied_num} satisfied') + + # print sbr info ================================================================= + intra_ps = np.transpose(np.where(restraints['sbr_mask']*mask_intrachain)) + inter_ps = np.transpose(np.where(restraints['sbr_mask']*(1-mask_intrachain))) + intra_sbr = int(len(intra_ps)/2) + inter_sbr = int(len(inter_ps)/2) + tot_sbr = intra_sbr+inter_sbr + if tot_sbr >0: + print(f'inter-residue restraints: {tot_sbr}({inter_sbr} inter-chain + {intra_sbr} intra-chain)') + if inter_sbr > 0: + print('Inter-chain restraints') + print_pair(inter_ps) + if intra_sbr > 0: + print('Intra-chain restraints') + print_pair(intra_ps) + + # update restraints based on plddts ============================================== + tot_before = int(tot_sbr+if_num) + restraints['interface_mask'] = includ_if * restraints['interface_mask'] + restraints['sbr_mask'] = includ_mat * restraints['sbr_mask'] + restraints['sbr'] = restraints['sbr'] * restraints['sbr_mask'][:,:,None] + tot_after = int((restraints['interface_mask']).sum() + (restraints['sbr_mask']).sum()/2) + rm_num = int(tot_before - tot_after) + + # compute recall, breakage + sbr_mask0 = restraints0['sbr_mask'] + sbr0 = restraints0['sbr'] + sbr_high0 = (sbr0 > (1 / sbr0.shape[-1])) + sbr_satis0 = (sbr_high0 * pred_dist_onehot).sum(-1) * pseudo_beta_mask_2d + + + + + interface_mask0 = restraints0['interface_mask'] + interface_satis0 = d.min(axis=1)<=8 + conf_2d = (plddt[None]+plddt[:, None])/2 + + recall_dict = { + 'interchain': (*compute_recall(sbr_satis0, sbr_mask0*np.triu(sbr_mask0)*(1-mask_intrachain), conf_2d), 1), + 'intrachain': (*compute_recall(sbr_satis0, sbr_mask0*np.triu(sbr_mask0)*mask_intrachain, conf_2d), 0.5), + 'interface': (*compute_recall(interface_satis0, interface_mask0, plddt), 1) + } + + recall_dict = { + k: v for k, v in recall_dict.items() if v[0] is not None + } + + + print('Breakage info ==========') + print(f'Break number: {break_num}, Max neighbour CA dist: {max_nb_dist}\n') + + print('Recall info=============') + recall = 0 + recall_conf = 0 + w = 0 + + for k, v in recall_dict.items(): + if v[0] is None: + continue + print(f'{k} (w {v[2]}): recall {v[0]}, recall weighted by confidence: {v[1]}') + recall += v[0]*v[2] + recall_conf += v[1]*v[2] + w += v[2] + + if w == 0: + # no restraints + recall = None + recall_conf = None + else: + recall /= w + recall_conf /= w + + return rm_num, break_num, max_nb_dist, recall, recall_conf, viol_num, max_viol_dist + + + + + + +def generate_split_first_num(split_file): + with open(split_file, 'r') as f: + split = [i.strip().split(',') for i in f.readlines()] + print(split) + return len(split[0]) + +def generate_index(lenlsls): + ls = [] + for lenls in lenlsls: + last = 0 + for l in lenls: + a = np.arange(l)+last + ls.append(a) + last = a[-1]+200 + return np.concatenate(ls, axis=0) + +def dict_update_keepdtype(d1, d2): + for k, v in d2.items(): + if k in d1: + d1[k] = v.astype(d1[k].dtype) + +def generate_id(seqs, first_num): + s1, s2 = [''.join(s) for s in [seqs[:first_num], seqs[first_num:]]] + if s1 == s2: + entity_id = np.repeat(1, len(s1)+len(s2)) + sym_id = np.repeat([1, 2], (len(s1), len(s2))) + else: + entity_id = np.repeat([1, 2], (len(s1), len(s2))) + sym_id = np.repeat(1, len(s1)+len(s2)) + asym_id = np.repeat([1, 2], (len(s1), len(s2))) + return asym_id, sym_id, entity_id + + +def 
update_feature_make_two_chains(feat, first_num, seqs): + lenls = np.array([len(i) for i in seqs]) + lenlsls = [lenls[:first_num], lenls[first_num:]] + + asym_id, sym_id, entity_id = generate_id(seqs, first_num) + d_update = { + 'residue_index': generate_index(lenlsls), + 'asym_id': asym_id, + 'sym_id': sym_id, + 'entity_id': entity_id, + 'assembly_num_chains': np.array(2) + } + dict_update_keepdtype(feat, d_update) + +def get_distri(cutoff, fdr): + xbool = np.concatenate([BINS, [np.inf]])<=cutoff + x = np.ones(len(BINS)+1) + x[xbool] = (1-fdr) * (x[xbool]/x[xbool].sum()) + x[~xbool] = fdr * (x[~xbool]/x[~xbool].sum()) + assert x[xbool].max() > x[~xbool].max(), (x[xbool].max(), x[~xbool].max()) + return x + +class SplitNamelist: + def __init__(self, rank_id, rank_size, outdir, rotate_split=False, key=None): + self.rank_id = rank_id + self.rank_size = rank_size + self.rotate_split = rotate_split + self.outdir = outdir + self.key = key + os.makedirs(self.outdir, exist_ok=True) + self.completion_flag_file = self.get_flag(self.rank_id) + + def get_flag(self, rank_id): + if self.key is None: + completion_flag_file = f'{self.outdir}/.complete_flag_rank{rank_id}.tmp' + else: + completion_flag_file = f'{self.outdir}/.complete_flag_rank{rank_id}_{self.key}.tmp' + return completion_flag_file + + def split_namelist(self, namelist): + if not self.rotate_split: + d, m = divmod(len(namelist), self.rank_size) + nums = np.repeat([d+1, d], [m, self.rank_size-m]) + start = int(nums[:self.rank_id].sum()) + namelist_slice = namelist[start: start+nums[self.rank_id]] + else: + namelist_slice = [] + for i in range(self.rank_id, len(namelist), self.rank_size): + namelist_slice.append(namelist[i]) + print(f'Rank {self.rank_id}/{self.rank_size}: {len(namelist_slice)}/{len(namelist)}: {namelist_slice[:2]} ...', flush=True) + return namelist_slice + + def start_job(self): + print(f'start job completion monitor for rank id {self.rank_id}') + if os.path.isfile(self.completion_flag_file): + os.remove(self.completion_flag_file) + assert (not os.path.exists(self.completion_flag_file)) + + def complete(self): + print(f'job complete for rank id {self.rank_id}') + print(f'generate temporary complete flag file {self.completion_flag_file}') + with open(self.completion_flag_file, 'w') as f: + f.write(f'job complete for rank {self.rank_id}') + + def check_all_complete(self): + comp_files = [i for i in range(self.rank_size) if os.path.exists(self.get_flag(i))] + not_finish = list(set(range(self.rank_size)) - set(comp_files)) + not_finish.sort() + print(f'current completion status: {len(comp_files)}/{self.rank_size}, not finished: {not_finish}') + return len(comp_files) == self.rank_size + + + +class DataGenerator: + def __init__(self, raw_feat_dir, fasta_dir=None, reorder=False): + self.raw_feat_dir = raw_feat_dir + self.fasta_dir = fasta_dir + self.reorder = reorder + self.files_dict = {} + + def get_pattern(self, pdb_id): + pat_dict = { + 'raw_feat': f'{self.raw_feat_dir}/{pdb_id}*.pkl', + 'fasta': f'{self.fasta_dir}/{pdb_id}*.fasta' + } + return pat_dict + + def _glob_file(self, pattern): + files = glob.glob(pattern) + assert len(files)<=1, files + return files[0] if len(files)==1 else None + + def get_files(self, pdb_id): + if pdb_id in self.files_dict: + return self.files_dict[pdb_id] + pat_dict = self.get_pattern(pdb_id) + file_dict = {k: self._glob_file(v) for k, v in pat_dict.items()} + self.files_dict[pdb_id] = file_dict + return file_dict + + def get_feat(self, pdb_id): + # raw feat + raw_pkl = 
self.get_files(pdb_id)['raw_feat'] + with open(raw_pkl, "rb") as f: + raw_feature = pickle.load(f) + if self.reorder: + print('reorder features') + seqs = list(self.get_seqs_dict(pdb_id).values()) + reorder_features(raw_feature, seqs) + return raw_feature + + def get_len(self, pdb_id): + raw_feat = self.get_feat(pdb_id) + return raw_feat['msa'].shape[1] if raw_feat is not None else 100000 + + def get_seqs_dict(self, pdb_id): + fasta_file = self.get_files(pdb_id)['fasta'] + return parse_fasta(fasta_file) + + def get_data(self): + # overwrite for specific cases + raise NotImplementedError + +class ModelGenerator: + def __init__(self, arguments, ckpt_dir): + data_cfg = load_config(arguments.data_config) + model_cfg = load_config(arguments.model_config) + # print("this is model_cfg: ", model_cfg) + self.seq_length = int(arguments.seq_len) + data_cfg.eval.crop_size = self.seq_length + model_cfg.seq_length = self.seq_length + slice_key = "seq_" + str(model_cfg.seq_length) + slice_val = vars(model_cfg.slice)[slice_key] + model_cfg.slice = slice_val + data_cfg.common.target_feat_dim = 21 # TARGET_FEAT_DIM + model_cfg.common.target_feat_dim = 21 # TARGET_FEAT_DIM + self.arguments = arguments + self.model_cfg = model_cfg + self.data_cfg = data_cfg + self.processed_feature = MultimerFeature(arguments.mixed_precision) + # ckpt + self.ckpt_dir = ckpt_dir + if not os.path.exists(self.ckpt_dir): + raise ValueError(f'checkpoint directory {self.ckpt_dir} does not exist') + self.last_ckpt = None + if os.path.isdir(self.ckpt_dir): + self.ckpt = None + else: + self.ckpt = self.ckpt_dir + + def get_ckpt(self, ckpt_id): + ckpt = f'{self.ckpt_dir}/step_{ckpt_id}.ckpt' + if not os.path.isfile(ckpt): + ckpt = f'{self.ckpt_dir}/{ckpt_id}.ckpt' + return ckpt + + def get_model(self, ckpt_id): + if self.ckpt is not None: + print(f'loading model from {self.ckpt}, not use ckpt id{ckpt_id}') + ckpt = self.ckpt + else: + ckpt = self.get_ckpt(ckpt_id) + if self.last_ckpt is None or self.last_ckpt != ckpt: + print(f'Initializing model from {ckpt}') + megafold_multimer = MegaFold(self.model_cfg, mixed_precision=self.arguments.mixed_precision, + device_num=self.arguments.device_num) + megafold_multimer.to_float(mstype.float16) + # print("debug network", megafold_multimer) + # fp32_white_list = (nn.Softmax, nn.LayerNorm) + # amp_convert(megafold_multimer, fp32_white_list) + # megafold_multimer = amp.auto_mixed_precision(megafold, amp_level="auto", dtype=mstype.float16) + self.model = megafold_multimer + params = load_checkpoint(ckpt) + params_infer = trans_ckpt(params) + + # print("preprocess_msa: ", params_infer['preprocess_msa.weight'].asnumpy()) + + for key in params_infer.keys(): + if "msa_row_attention_with_pair_bias.query_norm_gammas" in key: + print("debug", key) + load_param_into_net(self.model, params_infer) + return self.model + + def model_process_data(self, raw_feature): + feat = self.processed_feature.pipeline(self.model_cfg, self.data_cfg, raw_feature) + return feat + +def distance(points): + return np.sqrt(np.sum((points[:, None] - points[None, :])**2, + axis=-1)) + +def mask_mean(mask, value, eps=1e-10): + mask_shape = mask.shape + value_shape = value.shape + + axis = list(range(len(mask_shape))) + + broadcast_factor = 1. 
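+    # The factor below compensates for mask axes of size 1 that broadcast against
+    # value: e.g. a (N, 1) mask with a (N, M) value divides by mask.sum() * M,
+    # matching what an explicitly tiled mask would give.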
+    for axis_ in axis:
+        value_size = value_shape[axis_]
+        mask_size = mask_shape[axis_]
+        if mask_size == 1:
+            broadcast_factor *= value_size
+
+    return (np.sum(mask * value, axis=tuple(axis)) /
+            (np.sum(mask, axis=tuple(axis)) * broadcast_factor + eps))
+
+
+def recycle_cond(i, prev, next_in, feat, recycle_early_stop_tolerance):
+    print("start recycle_cond")
+
+    ca_idx = residue_constants.atom_order['CA']
+    sq_diff = np.square(distance(prev[:, ca_idx, :].astype(np.float64)) -
+                        distance(next_in[:, ca_idx, :].astype(np.float64)))
+    seq_mask_idx = 8
+    mask = feat[seq_mask_idx][:, None] * feat[seq_mask_idx][None, :]
+    sq_diff = mask_mean(mask.astype(np.float64), sq_diff)
+    diff = np.sqrt(sq_diff + 1e-8)
+    has_exceeded_tolerance = (
+        (i == 0) | bool(diff > recycle_early_stop_tolerance)
+    )
+    print(f"recycle {i} diff: {diff}")
+    print("end recycle_cond: ", has_exceeded_tolerance)
+    # mydict = {
+    #     'i': i,
+    #     'sq_diff': sq_diff.asnumpy(),
+    #     'diff': diff.asnumpy(),
+    #     'prev': prev.asnumpy(),
+    #     'next_in': next_in.asnumpy(),
+    #     'mask': mask.asnumpy()
+    # }
+    # with open(f'/job/file/rec_{i}.pkl', 'wb') as f:
+    #     pickle.dump(mydict, f)
+    return has_exceeded_tolerance, diff.item()
+
+def grasp_infer_quick(model_gen: ModelGenerator, ckpt_id, raw_feature: dict, restraints: dict, output_prefix,
+                      nbdist_ca_thre=5.0, viol_thre=5.0, mask_terminal_residues=0, iter=5, max_rm_ratio=0.2, left_ratio=0.2, same_msa_across_recycle=True,
+                      num_recycle=20, dtype=np.float16, seed=None, recycle_early_stop_tolerance=0.5):
+    print('Using quick inference')
+    ori_res_length = raw_feature['msa'].shape[1]
+    # run with no restraints provided
+    if restraints is None:
+        restraints = {
+            'sbr': np.zeros((ori_res_length, ori_res_length, len(BINS) + 1)),
+            'sbr_mask': np.zeros((ori_res_length, ori_res_length)),
+            'interface_mask': np.zeros(ori_res_length),
+            'asym_id': raw_feature['asym_id']
+        }
+
+    mydicts = []
+
+    restraints0 = restraints.copy()
+
+    t0 = time.time()
+    megafold_multimer = model_gen.get_model(ckpt_id)
+    seq_length = model_gen.seq_length
+
+    os.makedirs(os.path.dirname(output_prefix), exist_ok=True)
+
+    if seed is not None:
+        np.random.seed(seed)
+
+    feat_list = []
+
+    left_thre = (restraints['interface_mask'].sum() + restraints['sbr_mask'].sum()/2)*left_ratio
+    left_thre = int(np.ceil(left_thre))
+    print(f'At least {left_thre} restraints will be used in the final iteration')
+
+    # initialize prevs
+    prev_pos = Tensor(np.zeros([seq_length, 37, 3]).astype(dtype))
+    prev_msa_first_row = Tensor(np.zeros([seq_length, 256]).astype(dtype))
+    prev_pair = Tensor(np.zeros([seq_length, seq_length, 128]).astype(dtype))
+    prev_prev_pos = prev_pos.asnumpy()
+    next_in_prev_pos = prev_pos.asnumpy()
+    it = 0
+    num_recycle_cur_iter = 0
+    max_recycle_per_iter = 4
+
+    for i in range(num_recycle):
+
+        print("now it's recycle", i)
+
+        # pad restraints to fixed length
+        sbr = Tensor(np_pad(restraints['sbr'], seq_length, axis=(0, 1)).astype(dtype))
+        sbr_mask = Tensor(np_pad(restraints['sbr_mask'], seq_length, axis=(0, 1)).astype(dtype))
+        interface_mask = Tensor(np_pad(restraints['interface_mask'], seq_length, axis=0).astype(dtype))
+
+        # process data
+        f_i = 0 if same_msa_across_recycle else i
+        if len(feat_list)-1 < f_i:
+            feat_list.append(model_gen.model_process_data(raw_feature))
+        feat = feat_list[f_i]
+
+        # inference (single-card quick mode: feed the full, unsliced features)
+        feat_i = [Tensor(x) for x in feat]
+        prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits, aligned_error_logits, aligned_error_breaks = \
+            megafold_multimer(*feat_i, sbr, sbr_mask, interface_mask,
+                              prev_pos, prev_msa_first_row, prev_pair)
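+        # Early-stop bookkeeping (quick mode): an iteration ends once the CA
+        # distance map moves less than recycle_early_stop_tolerance between
+        # recycles, or after max_recycle_per_iter recycles, whichever comes first
+        # (see recycle_cond above).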
+        prev_prev_pos = next_in_prev_pos
+        next_in_prev_pos = prev_pos.asnumpy()
+        # compute diff
+        has_exceeded_tolerance, diff = recycle_cond(i, prev_prev_pos, next_in_prev_pos, feat, recycle_early_stop_tolerance)
+        num_recycle_cur_iter += 1
+
+        end_cur_iter = (not has_exceeded_tolerance) or (num_recycle_cur_iter >= max_recycle_per_iter)
+        print(f"iter: {it+1}, recycle: {i}, diff: {diff}, has_exceeded_tolerance: {has_exceeded_tolerance}, end_cur_iter: {end_cur_iter}", flush=True)
+
+        if end_cur_iter:
+            print(f"early stop: {i}, diff: {diff}, iter: {it+1} =============================")
+
+            # extract results
+            final_atom_positions, predicted_lddt_logits = [t.asnumpy()[:ori_res_length] for t in (prev_pos, predicted_lddt_logits)]
+            final_atom_mask = feat[16][:ori_res_length]
+            confidence, plddt = compute_confidence(predicted_lddt_logits, return_lddt=True)
+            b_factors = plddt[:, None] * final_atom_mask
+            aligned_error_logits = aligned_error_logits.asnumpy()[:ori_res_length, :ori_res_length]
+            ranking_score = compute_ranking_score(aligned_error_logits, aligned_error_breaks.asnumpy(), raw_feature['asym_id'])
+            ranking_score = round(ranking_score*100, 2)
+
+            unrelaxed_protein = from_prediction(final_atom_positions,
+                                                final_atom_mask,
+                                                feat[0][:ori_res_length],
+                                                feat[1][:ori_res_length],
+                                                b_factors,
+                                                feat[5][:ori_res_length] - 1,
+                                                remove_leading_feature_dimension=False)
+
+            # Write structures into pdb files
+            pdb_file = to_pdb(unrelaxed_protein)
+            pdb_path = f'{output_prefix}_iter{it+1}.pdb'
+            with open(pdb_path, 'w') as f:
+                f.write(pdb_file)
+
+            # filter restraints
+            print(f'Filter Restraints Iteration {it+1}')
+            rm_num, break_num, max_nb_dist, recall, recall_conf, viol_num, max_viol_dist = filter_restraints(restraints, restraints0, unrelaxed_protein, nbdist_ca_thre=nbdist_ca_thre, max_rm_ratio=max_rm_ratio, viol_thre=viol_thre, mask_terminal_residues=mask_terminal_residues)
+            print(f'Filter out {rm_num} restraint(s), confidence {confidence}, 0.8iptm+0.2ptm {ranking_score}')
+            rest = int(restraints['interface_mask'].sum() + restraints['sbr_mask'].sum()/2)
+
+            # record
+            assert rm_num >= 0, rm_num
+            mydict = {
+                'Iter': it+1,
+                'Conf': round(confidence, 3),
+                'RankScore': ranking_score,
+                'Total': rm_num+rest,
+                'Remove': rm_num,
+                'Rest': rest,
+                'MaxNbDist': max_nb_dist,
+                'BreakNum': break_num,
+                'Recall': recall,
+                'RecallByConf': recall_conf,
+                'Recycle_num': num_recycle_cur_iter,
+                'Diff': round(diff, 3),
+                'ViolNum': int(viol_num),
+                'MaxViolDist': round(max_viol_dist, 2),
+                'Time': round(time.time()-t0, 2)
+            }
+
+            mydicts.append(mydict)
+            if (rest <= left_thre) or (rm_num == 0) or (it >= iter-1):
+                break
+            t0 = time.time()
+            it += 1
+            num_recycle_cur_iter = 0
+    df = pd.DataFrame(mydicts)
+    return df
+
+def grasp_infer(model_gen: ModelGenerator, ckpt_id, raw_feature: dict, restraints: dict, output_prefix,
+                nbdist_ca_thre=5.0, viol_thre=5.0, mask_terminal_residues=0, iter=5, max_rm_ratio=0.2, left_ratio=0.2, baseline=False, same_msa_across_recycle=True,
+                num_recycle=20, dtype=np.float16, seed=None, recycle_early_stop_tolerance=0.5, device_num=8):
+
+    ori_res_length = raw_feature['msa'].shape[1]
+
+    # run with no restraints provided
+    if restraints is None:
+        restraints = {
+            'sbr': np.zeros((ori_res_length, ori_res_length, len(BINS) + 1)),
+            'sbr_mask': np.zeros((ori_res_length, ori_res_length)),
+            'interface_mask': np.zeros(ori_res_length),
+            'asym_id': raw_feature['asym_id']
+        }
+
+    mydicts = []
+
+    restraints0 = restraints.copy()
+
+    t0 = time.time()
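+    # Parallel layout (as used below, inferred from the slicing code): each of the
+    # device_num ranks holds a contiguous block of seq_length // device_num
+    # residues; per-residue features are sliced on dim 0 and pair features on
+    # dim 1 before every forward pass.
+    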
megafold_multimer = model_gen.get_model(ckpt_id) + seq_length = model_gen.seq_length + print("seq_length: ", seq_length) + os.makedirs(os.path.dirname(output_prefix), exist_ok=True) + + if seed is not None: + np.random.seed(seed) + + feat_list = [] + + left_thre = (restraints['interface_mask'].sum() + restraints['sbr_mask'].sum()/2)*left_ratio + left_thre = int(np.ceil(left_thre)) + print(f'At least {left_thre} restraints will be used in the final iteration') + print(f"iter is {iter}") + for it in range(iter): + rank = get_rank() + step = seq_length // device_num + # print("it: ", it, "iter: ", iter) + # pad restraints to fixed length + + sbr = Tensor(np_pad(restraints['sbr'], seq_length, axis=(0, 1)).astype(dtype)) + sbr_mask = Tensor(np_pad(restraints['sbr_mask'], seq_length, axis=(0, 1)).astype(dtype)) + interface_mask = Tensor(np_pad(restraints['interface_mask'], seq_length, axis=0).astype(dtype)) + + # +++++++++++++++ Fixed ++++++++++++++++++ + sbr = Tensor(sbr[rank*step : (rank + 1)*step, :, :]) + sbr_mask = Tensor(sbr_mask[rank*step : (rank + 1)*step, :]) + interface_mask = Tensor(interface_mask[rank*step : (rank + 1)*step]) + + + # initialize prevs + prev_pos = Tensor(np.zeros([seq_length, 37, 3]).astype(dtype)) + prev_msa_first_row = Tensor(np.zeros([seq_length, 256]).astype(dtype)) + prev_pair = Tensor(np.zeros([seq_length, seq_length, 128]).astype(dtype)) + + prev_prev_pos = prev_pos.asnumpy() + next_in_prev_pos = prev_pos.asnumpy() + # # +++++++++++++++ Fixed ++++++++++++++++++ + # prev_pos = Tensor(prev_pos[rank*step : (rank + 1)*step]) + prev_msa_first_row = Tensor(prev_msa_first_row[rank*step : (rank + 1)*step, :]) + prev_pair = Tensor(prev_pair[:, rank*step : (rank + 1)*step, :]) + + first_dim = [0, 1, 5, 6, 7, 8, 10, 15, 16] + print(f"num_recycle is {num_recycle}") + for i in range(num_recycle): + f_i = 0 if same_msa_across_recycle else i + if len(feat_list)-1 < f_i: + feat_list.append(model_gen.model_process_data(raw_feature)) + + feat = feat_list[f_i] + has_exceeded_tolerance, diff = recycle_cond(i, prev_prev_pos, next_in_prev_pos, feat, recycle_early_stop_tolerance) + + + # =============== Fixed ============== + # feat_i = [Tensor(x) for x in feat] + + # +++++++++++++++ Fixed ++++++++++++++ + + feat_i = [] + for index, x in enumerate(feat): + if index in first_dim: + feat_i.append(Tensor(x[rank*step : (rank+1)*step])) + else: + feat_i.append(Tensor(x[:, rank*step : (rank + 1)*step])) + + diff = round(diff, 3) + if not has_exceeded_tolerance: + print(f"early stop: {i}") + break + + print("--------------------start----------------------") + # +++++++++++++++ Fixed ++++++++++++++++++ + prev_pos = Tensor(prev_pos[rank*step : (rank + 1)*step]) + # prev_msa_first_row = Tensor(prev_msa_first_row[rank*step : (rank + 1)*step, :]) + # prev_pair = Tensor(prev_pair[:, rank*step : (rank + 1)*step, :]) + prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits, aligned_error_logits, aligned_error_breaks = megafold_multimer(*feat_i, + sbr, sbr_mask, interface_mask, + prev_pos, + prev_msa_first_row, + prev_pair) + print("--------------------end------------------------") + + prev_prev_pos = next_in_prev_pos + next_in_prev_pos = prev_pos.asnumpy() + del feat_i + gc.collect() + + + prev_pos, predicted_lddt_logits = [i.asnumpy()[:ori_res_length] for i in (prev_pos, predicted_lddt_logits)] + final_atom_positions = prev_pos + final_atom_mask = feat[16][:ori_res_length] + + confidence, plddt = compute_confidence(predicted_lddt_logits, return_lddt=True) + b_factors = plddt[:, None] * 
final_atom_mask
+        aligned_error_logits = aligned_error_logits.asnumpy()[:ori_res_length, :ori_res_length]
+        # ranking_score = compute_ranking_score(aligned_error_logits, aligned_error_breaks.asnumpy()[:ori_res_length, :ori_res_length], raw_feature['asym_id'])
+        # ranking_score = round(ranking_score*100, 2)
+
+        unrelaxed_protein = from_prediction(final_atom_positions,
+                                            final_atom_mask,
+                                            feat[0][:ori_res_length],
+                                            feat[1][:ori_res_length],
+                                            b_factors,
+                                            feat[5][:ori_res_length] - 1,
+                                            remove_leading_feature_dimension=False)
+
+        # Write structures into pdb files
+        pdb_file = to_pdb(unrelaxed_protein)
+        # pdb_path = f'{output_prefix}_score{ranking_score}_iter{it+1}.pdb'
+        pdb_path = f'{output_prefix}_iter{it+1}_recycle{num_recycle}_graph_parallel.pdb'
+        print(" ===================== pdb_path ==================== ", pdb_path)
+        with open(pdb_path, 'w') as f:
+            f.write(pdb_file)
+
+        # filter restraints
+        print(f'Filter Restraints Iteration {it+1} =============================================')
+        rm_num, break_num, max_nb_dist, recall, recall_conf, viol_num, max_viol_dist = filter_restraints(restraints, restraints0, unrelaxed_protein, nbdist_ca_thre=nbdist_ca_thre, max_rm_ratio=max_rm_ratio, viol_thre=viol_thre, mask_terminal_residues=mask_terminal_residues)
+        # print(f'Filter out {rm_num} restraint(s), confidence {confidence}, 0.8iptm+0.2ptm {ranking_score}')
+        rest = int(restraints['interface_mask'].sum() + restraints['sbr_mask'].sum()/2)
+
+        # record
+        assert rm_num >= 0, rm_num
+
+        tys = []
+        if ((confidence < 50) and (i == num_recycle-1) and (break_num > 20)):
+            tys.append('Failed')
+        if (recall is not None) and (recall < 0.01):
+            tys.append('LowRecall')
+        if (rest <= left_thre):
+            tys.append('RemoveThre')
+        if (rm_num == 0):
+            tys.append('Converged')
+        if (it == iter-1):
+            tys.append('LastIter')
+
+        if len(tys) == 0:
+            ty = 'Continue'
+        else:
+            ty = ','.join(tys)
+
+        # mydict = {
+        #     'Iter': it+1,
+        #     'Conf': round(confidence, 3),
+        #     'RankScore': ranking_score,
+        #     'Total': rm_num+rest,
+        #     'Remove': rm_num,
+        #     'Rest': rest,
+        #     'MaxNbDist': max_nb_dist,
+        #     'BreakNum': break_num,
+        #     'Recall': None if recall is None else round(recall, 2),
+        #     'RecallByConf': None if recall_conf is None else round(recall_conf, 3),
+        #     'Recycle_num': i+1,
+        #     'Diff': round(diff, 3),
+        #     'ViolNum': int(viol_num),
+        #     'MaxViolDist': None if max_viol_dist is None else round(max_viol_dist, 2),
+        #     'Time': round(time.time()-t0, 2),
+        #     'Type': ty
+        # }
+
+        # mydicts.append(mydict)
+        # t0 = time.time()
+        if len(tys) > 0:
+            print('Stop iteration:', ty, flush=True)
+            break
+    # df = pd.DataFrame(mydicts)
+    # return df
+    return
+
+def infer_batch(model_gen: ModelGenerator, data_gen: DataGenerator, sn: SplitNamelist, pdb_ids, ckpt_ids, res_dir, num_seed=5,
+                baseline=False, nbdist_ca_thre=5.0, viol_thre=5.0, mask_terminal_residues=2, iter=5,
+                num_recycle=20, recycle_early_stop_tolerance=0.5,
+                check_tsv_exist=True, quick=False):
+
+    os.makedirs(res_dir, exist_ok=True)
+
+    ori_pdb_id_num = len(pdb_ids)
+    pdb_ids = [i for i in pdb_ids if data_gen.get_len(i) <= model_gen.seq_length]
+    ckpt_ids = [i for i in ckpt_ids if os.path.isfile(model_gen.get_ckpt(i))]
+    ckpt_ids.sort()
+    print(f'Total pdbs: {len(pdb_ids)}, with {ori_pdb_id_num - len(pdb_ids)} pdb_ids removed because of length exceeding {model_gen.seq_length}')
+    print(f'Total ckpts: {len(ckpt_ids)}, {ckpt_ids}')
+
+    print("res_dir", res_dir)
+    if check_tsv_exist:
+        # skip cases whose result tables already exist
+        all_cases = [(ckpt_id, pdb_id) for ckpt_id in ckpt_ids for pdb_id in pdb_ids
+                     if len(glob.glob(f'{res_dir}/ckpt_{ckpt_id}_{pdb_id}*_info.tsv')) < num_seed]
+    else:
+        all_cases = [(ckpt_id, pdb_id) for ckpt_id in ckpt_ids for pdb_id in pdb_ids]
+
+    all_cases = sn.split_namelist(all_cases)
+    sn.start_job()
+    for ckpt_id, pdb_id in all_cases:
+        for seed in range(num_seed):
+            t1 = time.time()
+            output_prefix = f'{res_dir}/ckpt_{ckpt_id}_{pdb_id}_seed{seed}'
+            infofile = f'{output_prefix}_info.tsv'
+            raw_feature = data_gen.get_feat(pdb_id)
+            restraints = data_gen.get_data(pdb_id)
+            if raw_feature["aatype"].shape[0] > model_gen.seq_length:
+                print(f'length out of range {pdb_id}: sequence length {raw_feature["aatype"].shape[0]} > {model_gen.seq_length}')
+                continue
+
+            t2 = time.time()
+            if quick:
+                df = grasp_infer_quick(model_gen, ckpt_id, raw_feature, restraints, output_prefix, iter=iter, nbdist_ca_thre=nbdist_ca_thre, viol_thre=viol_thre, mask_terminal_residues=mask_terminal_residues, seed=seed, num_recycle=num_recycle,
+                                       recycle_early_stop_tolerance=recycle_early_stop_tolerance)
+            else:
+                df = grasp_infer(model_gen, ckpt_id, raw_feature, restraints, output_prefix, iter=iter, nbdist_ca_thre=nbdist_ca_thre, viol_thre=viol_thre, mask_terminal_residues=mask_terminal_residues, baseline=baseline, seed=seed,
+                                 num_recycle=num_recycle, recycle_early_stop_tolerance=recycle_early_stop_tolerance)
+            t3 = time.time()
+            if df is not None:  # grasp_infer currently returns None
+                df.to_csv(infofile, sep='\t', index=False)
+                print(df.to_string())
+            timings = f"[{datetime.datetime.now()}] ckpt step_{ckpt_id} prot_name {pdb_id} seed {seed}, pre_process_time {round(t2 - t1, 2)}, predict time {round(t3 - t2, 2)} , all_time {round(t3 - t1, 2)}"
+            print(timings)
+
+def infer_config(rotate_split, outdir, key=None):
+    # context.set_context(mode=context.PYNATIVE,
+    #                     device_target="Ascend",
+    #                     mempool_block_size="31GB",
+    #                     max_call_depth=6000)
+
+    os.environ["OPENBLAS_NUM_THREADS"] = "1"
+    os.environ["NUMEXPR_NUM_THREADS"] = "1"
+    os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
+    os.environ["MKL_NUM_THREADS"] = "1"
+    os.environ["OMP_NUM_THREADS"] = "1"
+
+    rank_id = int(os.getenv('RANK_ID', '0'))
+    device_id = int(os.getenv("DEVICE_ID", '0'))
+    rank_size = int(os.getenv('RANK_SIZE', '1'))
+
+    print('{}, rank id: {}, device id: {}, device num: {}, start to run...'.format(
+        datetime.datetime.now(), rank_id, device_id, rank_size), flush=True)
+    sn = SplitNamelist(rank_id, rank_size, outdir, rotate_split=rotate_split, key=key)
+    return sn
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Grasp/utils_xyh.py b/MindSPONGE/applications/research/Grasp/utils_xyh.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5fba2e80a6a2c9bdb8e8c15516d9b37f5996ea3
--- /dev/null
+++ b/MindSPONGE/applications/research/Grasp/utils_xyh.py
@@ -0,0 +1,52 @@
+import numpy as np
+import pandas as pd
+import pprint
+
+def show_npdict(npdict, tag=None):
+    '''print a dict of arrays elegantly'''
+    if tag:
+        print('*'*80)
+        print(f'*{tag:^78}*')
+        print('*'*80)
+        print('\n')
+
+    for k in sorted(list(npdict.keys())):
+        v = npdict[k]
+        if isinstance(v, np.ndarray):
+            print(f'{f"{k}: {v.shape}, {v.dtype}":-<80}')
+            if len(v.shape) == 0:
+                print(v)
+                continue
+            v1 = v.copy()
+            while len(v1.shape) > 1 and v1.shape[0] > 0:
+                v1 = v1[0]
+            print(v1[:min(10, len(v1))])
+        else:
+            print(f'{f"{k}, {type(v)}":-<80}')
+            pprint.pprint(v)
+        print('')
+
+def reduce_dim(x, num):
+    if not isinstance(x, np.ndarray):
+        x = x.asnumpy()
+    while len(x.shape) > num:
+        x = x[0]
+    return x
+
+def print_restraint_info(d1):
+    '''print sampled restraints' information'''
+    d = d1.copy()
+    contact_mask_input = reduce_dim(d["contact_mask_input"], 2)
+    contact_mask_output = reduce_dim(d["contact_mask_output"], 2)
+    true_contact = contact_mask_input * contact_mask_output
+    false_contact = contact_mask_input * (1 - contact_mask_output)
+    asym_id = reduce_dim(d['asym_id'], 1)
+    is_intra = (asym_id[None] == asym_id[:, None])
+    true_inter = (true_contact * (1 - is_intra)).sum() / 2
+    true_intra = (true_contact * is_intra).sum() / 2
+    false_inter = (false_contact 
* (1 - is_intra)).sum() / 2 + false_intra = (false_contact * is_intra).sum() / 2 + df = pd.DataFrame([[true_inter, true_intra], [false_inter, false_intra]], columns=['inter', 'intra'], index=['true', 'false']) + df['sum'] = df.sum(1) + df.loc['sum'] = df.sum(0) + print(df)
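+
+# Usage sketch (illustrative only; the key names follow print_restraint_info above
+# and the shapes are assumptions):
+#   d = {'contact_mask_input': np.ones((8, 8)),
+#        'contact_mask_output': np.eye(8),
+#        'asym_id': np.repeat([1, 2], 4)}
+#   print_restraint_info(d)  # tabulates true/false restraints split into inter/intra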