diff --git a/MindSPONGE/examples/polypeptide/README_CN.md b/MindSPONGE/examples/polypeptide/README_CN.md
index 7ad3a42ba843682de6dfa600c1972d2adc5fbbaa..51998c66a25264c384a04e07c9e9ff6e2a14ca31 100644
--- a/MindSPONGE/examples/polypeptide/README_CN.md
+++ b/MindSPONGE/examples/polypeptide/README_CN.md
@@ -16,8 +16,14 @@
- [运行结果](#运行结果)
- [性能描述](#性能描述)
+ - [MindSpore.Numpy方式运行SPONGE](#以MindSpore.Numpy方式运行SPONGE)
+ - [MindSPONGE-Numpy运行机制](#MindSPONGE-Numpy运行机制)
+ - [CUDA核函数与MindSpore的映射及迁移](#CUDA核函数与MindSpore的映射及迁移)
+ - [使用图算融合/算子自动生成进行加速](#使用图算融合/算子自动生成进行加速)
+ - [性能描述](#性能描述)
+
-
+
## 概述
@@ -207,3 +213,117 @@ _steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ _ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _1
| Speed | 5.0 ms/step
| Total time | 5.7 s
| Script | [Link](https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE/mindsponge/scripts)
+
+## MindSpore.Numpy方式运行SPONGE
+
+除了以Cuda核函数执行的方式运行SPONGE之外,现在我们同时支持以MindSpore原生表达的方式运行SPONGE。计算能量,坐标和力的Cuda核函数均被替换成了Numpy的语法表达,同时拥有MindSpore强大的加速能力。
+
+Sponge-Numpy现在同样支持丙氨酸三肽水溶液体系,如果需要运行Sponge-Numpy,可以使用如下命令:
+
+```shell
+python main_numpy.py --i /path/NVT_290_10ns.in \
+ --amber_parm /path/WATER_ALA.parm7 \
+ --c /path/WATER_ALA_350_cool_290.rst7 \
+ --o /path/ala_NVT_290_10ns.out
+```
+
+或者直接在 MindSPONGE / examples / polypeptide / scripts 目录下执行:
+
+```shell
+bash run_numpy.sh
+```
+
+其余步骤均与[丙氨酸三肽水溶液](https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/examples/polypeptide/case_polypeptide.md)保持一致。
+
+## MindSPONGE-Numpy运行机制
+
+为了更充分地利用MindSpore的强大特性,以及更好地展示分子动力学算法的运作机制, SPONGE中的Cuda核函数被重构为Numpy语法的脚本,并封装在[md/functions](https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE/mindsponge/md/functions)模块之中。
+
+MindSpore.Numpy计算套件包含一套完整的符合Numpy规范的接口,使得开发者可以Numpy的原生语法表达MindSpore的模型,同时拥有MindSpore的加速能力。MindSpore.Numpy是建立在MindSpore基础算子(mindspore.ops)之上的一层封装,以MindSpore张量为计算单元,因此可以与其他MindSpore特性完全兼容。更多介绍请参考[这里](https://www.mindspore.cn/docs/programming_guide/en/master/numpy.html)。
+
+### CUDA核函数与MindSpore的映射及迁移
+
+在丙氨酸三肽水溶液案例中,所有的Cuda核函数均完成了MindSpore改写并完成了精度验证。对于Cuda算法的Numpy迁移,现在提供一个计算LJ Energy的案例供参考:
+
+```cuda
+for (int j = threadIdx.y; j < N; j = j + blockDim.y) {
+ atom_j = nl_i.atom_serial[j];
+ r2 = uint_crd[atom_j];
+
+ int_x = r2.uint_x - r1.uint_x;
+ int_y = r2.uint_y - r1.uint_y;
+ int_z = r2.uint_z - r1.uint_z;
+ dr.x = boxlength[0].x * int_x;
+ dr.y = boxlength[0].y * int_y;
+ dr.z = boxlength[0].z * int_z;
+
+ dr2 = dr.x * dr.x + dr.y * dr.y + dr.z * dr.z;
+ if (dr2 < cutoff_square) {
+ dr_2 = 1. / dr2;
+ dr_4 = dr_2 * dr_2;
+ dr_6 = dr_4 * dr_2;
+
+ y = (r2.lj_type - r1.lj_type);
+ x = y >> 31;
+ y = (y ^ x) - x;
+ x = r2.lj_type + r1.lj_type;
+ r2.lj_type = (x + y) >> 1;
+ x = (x - y) >> 1;
+ atom_pair_lj_type = (r2.lj_type * (r2.lj_type + 1) >> 1) + x;
+
+ dr_2 = (0.083333333 * lj_type_A[atom_pair_lj_type] * dr_6 - 0.166666666 * lj_type_B[atom_pair_lj_type]) * dr_6;
+ ene_lin = ene_lin + dr_2;
+ }
+ }
+ atomicAdd(&lj_ene[atom_i], ene_lin);
+```
+
+以上代码首先计算了当前分子与其邻居的距离,对于距离小于`cutoff_square`的分子对,进行后续的能量计算,并且累加到当前分子之上,作为该分子累积能量的一部分。因此,Mindore.Numpy版本的迁移分为两部分:
+
+- 理解Cuda核函数算法
+- 进行Numpy拆分以及映射
+
+重构之后的Numpy脚本如下:
+
+```python
+nl_atom_serial_crd = uint_crd[nl_atom_serial]
+r2_lj_type = atom_lj_type[nl_atom_serial]
+crd_expand = np.expand_dims(uint_crd, 1)
+crd_d = get_periodic_displacement(nl_atom_serial_crd, crd_expand, scaler)
+crd_2 = crd_d ** 2
+crd_2 = np.sum(crd_2, -1)
+nl_atom_mask = get_neighbour_index(atom_numbers, nl_atom_serial.shape[1])
+mask = np.logical_and((crd_2 < cutoff_square), (nl_atom_mask < np.expand_dims(nl_atom_numbers, -1)))
+dr_2 = 1. / crd_2
+dr_6 = np.power(dr_2, 3.)
+r1_lj_type = np.expand_dims(atom_lj_type, -1)
+x = r2_lj_type + r1_lj_type
+y = np.absolute(r2_lj_type - r1_lj_type)
+r2_lj_type = (x + y) // 2
+x = (x - y) // 2
+atom_pair_lj_type = (r2_lj_type * (r2_lj_type + 1) // 2) + x
+dr_2 = (0.083333333 * lj_A[atom_pair_lj_type] * dr_6 - 0.166666666 * lj_B[atom_pair_lj_type]) * dr_6
+ene_lin = np.where(mask, dr_2, zero_tensor)
+ene_lin = np.sum(ene_lin, -1)
+return ene_lin
+```
+
+具体步骤如下:
+
+- 将Cuda中的索引取值改写为Numpy的fancy index索引取值。
+- 建立一个掩码矩阵,将所有距离大于`cutoff_square`的计算屏蔽。
+- 将所有元素级的运算变换为可以广播的Numpy的矩阵计算,中间可能涉及矩阵的形状变换。
+
+### 使用图算融合/算子自动生成进行加速
+
+为了获得成倍的加速收益,MindSPONGE-Numpy默认开启[图算融合](https://www.mindspore.cn/docs/programming_guide/en/master/enable_graph_kernel_fusion.html)以及[自动算子生成] (https://gitee.com/mindspore/akg)。这两个加速组件可以为模型提供3倍(甚至更多)的性能提升, 使得MindSPONGE-Numpy达到与原版本性能相近的程度.
+
+在模型脚本中添加如下两行代码即可获得图算融合加速:
+(examples/polypeptide/src/main_numpy.py):
+
+```python
+# Enable Graph Mode, with GPU as backend, and allow Graph Kernel Fusion
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, enable_graph_kernel=True)
+# Make fusion rules for specific operators
+context.set_context(graph_kernel_flags="--enable_expand_ops=Gather --enable_cluster_ops=TensorScatterAdd --enable_recompute_fusion=false")
+```
diff --git a/MindSPONGE/examples/polypeptide/scripts/run_numpy.sh b/MindSPONGE/examples/polypeptide/scripts/run_numpy.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f8226a7f1b9d88de69ef07d8d9fc2dce822d9d19
--- /dev/null
+++ b/MindSPONGE/examples/polypeptide/scripts/run_numpy.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+python ../src/main_numpy.py --i ../data/NVT_290_10ns.in --amber_parm ../data/WATER_ALA.parm7 --c ../data/WATER_ALA_350_cool_290.rst7
diff --git a/MindSPONGE/examples/polypeptide/src/main_numpy.py b/MindSPONGE/examples/polypeptide/src/main_numpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc763aa253a05ff87f544fddccd9858ee0e010fb
--- /dev/null
+++ b/MindSPONGE/examples/polypeptide/src/main_numpy.py
@@ -0,0 +1,65 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''main'''
+import argparse
+import time
+
+import mindspore.context as context
+
+from simulation_np import Simulation
+
+parser = argparse.ArgumentParser(description='SPONGE Controller')
+parser.add_argument('--i', type=str, default=None, help='Input file')
+parser.add_argument('--amber_parm', type=str, default=None, help='Paramter file in AMBER type')
+parser.add_argument('--c', type=str, default=None, help='Initial coordinates file')
+parser.add_argument('--r', type=str, default="restrt", help='')
+parser.add_argument('--x', type=str, default="mdcrd", help='')
+parser.add_argument('--o', type=str, default="mdout", help='Output file')
+parser.add_argument('--box', type=str, default="mdbox", help='')
+parser.add_argument('--device_id', type=int, default=0, help='GPU device id')
+parser.add_argument('--u', type=bool, default=False, help='If use mdnn to update the atom charge')
+parser.add_argument('--checkpoint', type=str, default="", help='Checkpoint file')
+args_opt = parser.parse_args()
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id,
+ save_graphs=False, enable_graph_kernel=True)
+context.set_context(graph_kernel_flags="--enable_expand_ops=Gather \
+ --enable_cluster_ops=TensorScatterAdd,UnSortedSegmentSum,GatherNd \
+ --enable_recompute_fusion=false --enable_parallel_fusion=true")
+
+if __name__ == "__main__":
+ simulation = Simulation(args_opt)
+
+ start = time.time()
+ compiler_time = 0
+ save_path = args_opt.o
+ simulation.main_initial()
+ for steps in range(simulation.md_info.step_limit):
+ print_step = steps % simulation.ntwx
+ if steps == simulation.md_info.step_limit - 1:
+ print_step = 0
+ temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
+ nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene, _ = simulation()
+
+ if steps == 0:
+ compiler_time = time.time()
+ if steps % simulation.ntwx == 0 or steps == simulation.md_info.step_limit - 1:
+ simulation.main_print(steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene,
+ sigma_of_dihedral_ene, nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene)
+
+ end = time.time()
+ print("Main time(s):", end - start)
+ print("Main time(s) without compiler:", end - compiler_time)
+ simulation.main_destroy()
diff --git a/MindSPONGE/examples/polypeptide/src/simulation_np.py b/MindSPONGE/examples/polypeptide/src/simulation_np.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef13323732d4a3a109beba4f148c12cd2126efc2
--- /dev/null
+++ b/MindSPONGE/examples/polypeptide/src/simulation_np.py
@@ -0,0 +1,449 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''Simulation'''
+import numpy as np
+
+import mindspore.common.dtype as mstype
+from mindspore import Tensor
+from mindspore import nn
+from mindspore.common.parameter import Parameter
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+from mindsponge import Angle
+from mindsponge import Bond
+from mindsponge import Dihedral
+from mindsponge import LangevinLiujian
+from mindsponge import LennardJonesInformation
+from mindsponge import MdInformation
+from mindsponge import NonBond14
+from mindsponge import NeighborList
+from mindsponge import ParticleMeshEwald
+
+from mindsponge.md.functions import angle_energy, angle_force_with_atom_energy, bond_force_with_atom_energy, \
+ crd_to_uint_crd, dihedral_14_ljcf_force_with_atom_energy, lj_energy, \
+ md_iteration_leap_frog_liujian, pme_excluded_force, lj_force_pme_direct_force, \
+ dihedral_force_with_atom_energy, bond_energy, nb14_lj_energy, \
+ nb14_cf_energy, md_temperature, dihedral_energy, pme_energy, \
+ pme_reciprocal_force, reform_excluded_list, get_pme_bc
+
+
+class Controller:
+ '''controller'''
+
+ def __init__(self, args_opt):
+ self.input_file = args_opt.i
+ self.initial_coordinates_file = args_opt.c
+ self.amber_parm = args_opt.amber_parm
+ self.restrt = args_opt.r
+ self.mdcrd = args_opt.x
+ self.mdout = args_opt.o
+ self.mdbox = args_opt.box
+
+ self.command_set = {}
+ self.md_task = None
+ self.commands_from_in_file()
+
+ def commands_from_in_file(self):
+ '''command from in file'''
+ file = open(self.input_file, 'r')
+ context = file.readlines()
+ file.close()
+ self.md_task = context[0].strip()
+ for val in context:
+ if "=" in val:
+ assert len(val.strip().split("=")) == 2
+ flag, value = val.strip().split("=")
+ value = value.replace(",", '')
+ flag = flag.replace(" ", "")
+ if flag not in self.command_set:
+ self.command_set[flag] = value
+ else:
+ print("ERROR COMMAND FILE")
+
+
+class Simulation(nn.Cell):
+ '''simulation'''
+
+ def __init__(self, args_opt):
+ super(Simulation, self).__init__()
+ self.control = Controller(args_opt)
+ self.md_info = MdInformation(self.control)
+ self.bond = Bond(self.control)
+ self.angle = Angle(self.control)
+ self.dihedral = Dihedral(self.control)
+ self.nb14 = NonBond14(self.control, self.dihedral, self.md_info.atom_numbers)
+ self.nb_info = NeighborList(self.control, self.md_info.atom_numbers, self.md_info.box_length)
+ self.lj_info = LennardJonesInformation(self.control, self.md_info.nb.cutoff, self.md_info.sys.box_length)
+ self.liujian_info = LangevinLiujian(self.control, self.md_info.atom_numbers)
+ self.pme_method = ParticleMeshEwald(self.control, self.md_info)
+ self.bond_energy_sum = Tensor(0, mstype.int32)
+ self.angle_energy_sum = Tensor(0, mstype.int32)
+ self.dihedral_energy_sum = Tensor(0, mstype.int32)
+ self.nb14_lj_energy_sum = Tensor(0, mstype.int32)
+ self.nb14_cf_energy_sum = Tensor(0, mstype.int32)
+ self.lj_energy_sum = Tensor(0, mstype.int32)
+ self.ee_ene = Tensor(0, mstype.int32)
+ self.total_energy = Tensor(0, mstype.int32)
+ # Init scalar
+ self.ntwx = self.md_info.ntwx
+ self.atom_numbers = self.md_info.atom_numbers
+ self.residue_numbers = self.md_info.residue_numbers
+ self.bond_numbers = self.bond.bond_numbers
+ self.angle_numbers = self.angle.angle_numbers
+ self.dihedral_numbers = self.dihedral.dihedral_numbers
+ self.nb14_numbers = self.nb14.nb14_numbers
+ self.nxy = self.nb_info.nxy
+ self.grid_numbers = self.nb_info.grid_numbers
+ self.max_atom_in_grid_numbers = self.nb_info.max_atom_in_grid_numbers
+ self.max_neighbor_numbers = self.nb_info.max_neighbor_numbers
+ self.excluded_atom_numbers = self.md_info.nb.excluded_atom_numbers
+ self.refresh_count = Parameter(Tensor(self.nb_info.refresh_count, mstype.int32), requires_grad=False)
+ self.refresh_interval = self.nb_info.refresh_interval
+ self.skin = self.nb_info.skin
+ self.cutoff = self.nb_info.cutoff
+ self.cutoff_square = self.nb_info.cutoff_square
+ self.cutoff_with_skin = self.nb_info.cutoff_with_skin
+ self.half_cutoff_with_skin = self.nb_info.half_cutoff_with_skin
+ self.cutoff_with_skin_square = self.nb_info.cutoff_with_skin_square
+ self.half_skin_square = self.nb_info.half_skin_square
+ self.beta = self.pme_method.beta
+ self.fftx = self.pme_method.fftx
+ self.ffty = self.pme_method.ffty
+ self.fftz = self.pme_method.fftz
+ self.box_length_0 = self.md_info.box_length[0]
+ self.box_length_1 = self.md_info.box_length[1]
+ self.box_length_2 = self.md_info.box_length[2]
+ self.random_seed = self.liujian_info.random_seed
+ self.dt = self.liujian_info.dt
+ self.half_dt = self.liujian_info.half_dt
+ self.exp_gamma = self.liujian_info.exp_gamma
+ self.init_tensor()
+ self.op_define()
+ self.update = False
+
+ def init_tensor(self):
+ '''init tensor'''
+ self.crd = Parameter(
+ Tensor(np.float32(np.asarray(self.md_info.coordinate).reshape([self.atom_numbers, 3])), mstype.float32),
+ requires_grad=False)
+ self.crd_to_uint_crd_cof = Tensor(np.asarray(self.md_info.pbc.crd_to_uint_crd_cof, np.float32), mstype.float32)
+ self.uint_dr_to_dr_cof = Parameter(
+ Tensor(np.asarray(self.md_info.pbc.uint_dr_to_dr_cof, np.float32), mstype.float32), requires_grad=False)
+ self.box_length = Tensor(self.md_info.box_length, mstype.float32)
+ self.charge = Parameter(Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32),
+ requires_grad=False)
+ self.old_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32),
+ requires_grad=False)
+ self.last_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32),
+ requires_grad=False)
+ self.uint_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32),
+ requires_grad=False)
+ self.mass_inverse = Tensor(self.md_info.h_mass_inverse, mstype.float32)
+ self.res_start = Tensor(self.md_info.h_res_start, mstype.int32)
+ self.res_end = Tensor(self.md_info.h_res_end, mstype.int32)
+ self.mass = Tensor(self.md_info.h_mass, mstype.float32)
+ self.velocity = Parameter(Tensor(self.md_info.velocity, mstype.float32), requires_grad=False)
+ self.acc = Parameter(Tensor(np.zeros([self.atom_numbers, 3], np.float32), mstype.float32), requires_grad=False)
+ self.bond_atom_a = Tensor(np.asarray(self.bond.h_atom_a, np.int32), mstype.int32)
+ self.bond_atom_b = Tensor(np.asarray(self.bond.h_atom_b, np.int32), mstype.int32)
+ self.bond_k = Tensor(np.asarray(self.bond.h_k, np.float32), mstype.float32)
+ self.bond_r0 = Tensor(np.asarray(self.bond.h_r0, np.float32), mstype.float32)
+ self.angle_atom_a = Tensor(np.asarray(self.angle.h_atom_a, np.int32), mstype.int32)
+ self.angle_atom_b = Tensor(np.asarray(self.angle.h_atom_b, np.int32), mstype.int32)
+ self.angle_atom_c = Tensor(np.asarray(self.angle.h_atom_c, np.int32), mstype.int32)
+ self.angle_k = Tensor(np.asarray(self.angle.h_angle_k, np.float32), mstype.float32)
+ self.angle_theta0 = Tensor(np.asarray(self.angle.h_angle_theta0, np.float32), mstype.float32)
+ self.dihedral_atom_a = Tensor(np.asarray(self.dihedral.h_atom_a, np.int32), mstype.int32)
+ self.dihedral_atom_b = Tensor(np.asarray(self.dihedral.h_atom_b, np.int32), mstype.int32)
+ self.dihedral_atom_c = Tensor(np.asarray(self.dihedral.h_atom_c, np.int32), mstype.int32)
+ self.dihedral_atom_d = Tensor(np.asarray(self.dihedral.h_atom_d, np.int32), mstype.int32)
+ self.pk = Tensor(np.asarray(self.dihedral.h_pk, np.float32), mstype.float32)
+ self.gamc = Tensor(np.asarray(self.dihedral.h_gamc, np.float32), mstype.float32)
+ self.gams = Tensor(np.asarray(self.dihedral.h_gams, np.float32), mstype.float32)
+ self.pn = Tensor(np.asarray(self.dihedral.h_pn, np.float32), mstype.float32)
+ self.ipn = Tensor(np.asarray(self.dihedral.h_ipn, np.int32), mstype.int32)
+ self.nb14_atom_a = Tensor(np.asarray(self.nb14.h_atom_a, np.int32), mstype.int32)
+ self.nb14_atom_b = Tensor(np.asarray(self.nb14.h_atom_b, np.int32), mstype.int32)
+ self.lj_scale_factor = Tensor(np.asarray(self.nb14.h_lj_scale_factor, np.float32), mstype.float32)
+ self.cf_scale_factor = Tensor(np.asarray(self.nb14.h_cf_scale_factor, np.float32), mstype.float32)
+ self.grid_n = Tensor(self.nb_info.grid_n, mstype.int32)
+ self.grid_length_inverse = Tensor(self.nb_info.grid_length_inverse, mstype.float32)
+ self.bucket = Parameter(Tensor(
+ np.asarray(self.nb_info.bucket, np.int32).reshape([self.grid_numbers, self.max_atom_in_grid_numbers]),
+ mstype.int32), requires_grad=False)
+ self.atom_numbers_in_grid_bucket = Parameter(Tensor(self.nb_info.atom_numbers_in_grid_bucket, mstype.int32),
+ requires_grad=False)
+ self.atom_in_grid_serial = Parameter(Tensor(np.zeros([self.nb_info.atom_numbers,], np.int32), mstype.int32),
+ requires_grad=False)
+ self.pointer = Parameter(
+ Tensor(np.asarray(self.nb_info.pointer, np.int32).reshape([self.grid_numbers, 125]), mstype.int32),
+ requires_grad=False)
+ self.nl_atom_numbers = Parameter(Tensor(np.zeros([self.atom_numbers,], np.int32), mstype.int32),
+ requires_grad=False)
+ self.nl_atom_serial = Parameter(
+ Tensor(np.zeros([self.atom_numbers, self.max_neighbor_numbers], np.int32), mstype.int32),
+ requires_grad=False)
+ excluded_list_start = np.asarray(self.md_info.nb.h_excluded_list_start, np.int32)
+ excluded_list = np.asarray(self.md_info.nb.h_excluded_list, np.int32)
+ excluded_numbers = np.asarray(self.md_info.nb.h_excluded_numbers, np.int32)
+ self.excluded_list_start = Tensor(excluded_list_start)
+ self.excluded_list = Tensor(excluded_list)
+ self.excluded_numbers = Tensor(excluded_numbers)
+ self.excluded_matrix = Tensor(reform_excluded_list(excluded_list, excluded_list_start, excluded_numbers))
+ box = (self.box_length_0, self.box_length_1, self.box_length_2)
+ self.pme_bc = Tensor(get_pme_bc(self.fftx, self.ffty, self.fftz, box, self.beta), mstype.float32)
+ self.need_refresh_flag = Tensor(np.asarray([0], np.int32), mstype.int32)
+ self.atom_lj_type = Tensor(self.lj_info.atom_lj_type, mstype.int32)
+ self.lj_a = Tensor(self.lj_info.h_lj_a, mstype.float32)
+ self.lj_b = Tensor(self.lj_info.h_lj_b, mstype.float32)
+ self.sqrt_mass = Tensor(self.liujian_info.h_sqrt_mass, mstype.float32)
+ self.rand_state = Parameter(Tensor(self.liujian_info.rand_state, mstype.float32))
+ self.zero_fp_tensor = Tensor(np.asarray([0,], np.float32))
+
+ def op_define(self):
+ '''op define'''
+ self.neighbor_list_update_init = P.NeighborListUpdate(grid_numbers=self.grid_numbers,
+ atom_numbers=self.atom_numbers, not_first_time=0,
+ nxy=self.nxy,
+ excluded_atom_numbers=self.excluded_atom_numbers,
+ cutoff_square=self.cutoff_square,
+ half_skin_square=self.half_skin_square,
+ cutoff_with_skin=self.cutoff_with_skin,
+ half_cutoff_with_skin=self.half_cutoff_with_skin,
+ cutoff_with_skin_square=self.cutoff_with_skin_square,
+ refresh_interval=self.refresh_interval,
+ cutoff=self.cutoff, skin=self.skin,
+ max_atom_in_grid_numbers=self.max_atom_in_grid_numbers,
+ max_neighbor_numbers=self.max_neighbor_numbers)
+ self.neighbor_list_update = P.NeighborListUpdate(grid_numbers=self.grid_numbers, atom_numbers=self.atom_numbers,
+ not_first_time=1, nxy=self.nxy,
+ excluded_atom_numbers=self.excluded_atom_numbers,
+ cutoff_square=self.cutoff_square,
+ half_skin_square=self.half_skin_square,
+ cutoff_with_skin=self.cutoff_with_skin,
+ half_cutoff_with_skin=self.half_cutoff_with_skin,
+ cutoff_with_skin_square=self.cutoff_with_skin_square,
+ refresh_interval=self.refresh_interval, cutoff=self.cutoff,
+ skin=self.skin,
+ max_atom_in_grid_numbers=self.max_atom_in_grid_numbers,
+ max_neighbor_numbers=self.max_neighbor_numbers)
+
+ def simulation_beforce_caculate_force(self):
+ '''simulation before calculate force'''
+ crd_to_uint_crd_cof = 0.5 * self.crd_to_uint_crd_cof
+ uint_crd = crd_to_uint_crd(crd_to_uint_crd_cof, self.crd)
+ return uint_crd
+
+ def simulation_caculate_force(self, uint_crd, scaler, nl_atom_numbers, nl_atom_serial):
+ '''simulation calculate force'''
+ bond_f, _ = bond_force_with_atom_energy(self.atom_numbers, self.bond_numbers,
+ uint_crd, scaler, self.bond_atom_a,
+ self.bond_atom_b, self.bond_k,
+ self.bond_r0)
+
+ angle_f, _ = angle_force_with_atom_energy(self.angle_numbers, uint_crd,
+ scaler, self.angle_atom_a,
+ self.angle_atom_b, self.angle_atom_c,
+ self.angle_k, self.angle_theta0)
+
+ dihedral_f, _ = dihedral_force_with_atom_energy(self.dihedral_numbers,
+ uint_crd, scaler,
+ self.dihedral_atom_a,
+ self.dihedral_atom_b,
+ self.dihedral_atom_c,
+ self.dihedral_atom_d, self.ipn,
+ self.pk, self.gamc, self.gams,
+ self.pn)
+
+ nb14_f, _ = dihedral_14_ljcf_force_with_atom_energy(self.atom_numbers, uint_crd,
+ self.atom_lj_type, self.charge,
+ scaler, self.nb14_atom_a, self.nb14_atom_b,
+ self.lj_scale_factor, self.cf_scale_factor,
+ self.lj_a, self.lj_b)
+
+ lj_f = lj_force_pme_direct_force(self.atom_numbers, self.cutoff_square, self.beta,
+ uint_crd, self.atom_lj_type, self.charge, scaler,
+ nl_atom_numbers, nl_atom_serial, self.lj_a, self.lj_b)
+
+ pme_excluded_f = pme_excluded_force(self.atom_numbers, self.beta, uint_crd, scaler,
+ self.charge, self.excluded_matrix)
+
+ pme_reciprocal_f = pme_reciprocal_force(self.atom_numbers, self.fftx, self.ffty,
+ self.fftz, self.box_length_0, self.box_length_1,
+ self.box_length_2, self.pme_bc, uint_crd, self.charge)
+ force = bond_f + angle_f + dihedral_f + nb14_f + lj_f + pme_excluded_f + pme_reciprocal_f
+ return force
+
+ def simulation_caculate_energy(self, uint_crd, uint_dr_to_dr_cof):
+ '''simulation calculate energy'''
+ bond_e = bond_energy(self.atom_numbers, self.bond_numbers, uint_crd, uint_dr_to_dr_cof,
+ self.bond_atom_a, self.bond_atom_b, self.bond_k, self.bond_r0)
+ bond_energy_sum = bond_e.sum(keepdims=True)
+
+ angle_e = angle_energy(self.angle_numbers, uint_crd, uint_dr_to_dr_cof,
+ self.angle_atom_a, self.angle_atom_b, self.angle_atom_c,
+ self.angle_k, self.angle_theta0)
+ angle_energy_sum = angle_e.sum(keepdims=True)
+
+ dihedral_e = dihedral_energy(self.dihedral_numbers, uint_crd, uint_dr_to_dr_cof,
+ self.dihedral_atom_a, self.dihedral_atom_b,
+ self.dihedral_atom_c, self.dihedral_atom_d,
+ self.ipn, self.pk, self.gamc, self.gams,
+ self.pn)
+ dihedral_energy_sum = dihedral_e.sum(keepdims=True)
+
+ nb14_lj_e = nb14_lj_energy(self.nb14_numbers, self.atom_numbers, uint_crd,
+ self.atom_lj_type, self.charge, uint_dr_to_dr_cof,
+ self.nb14_atom_a, self.nb14_atom_b, self.lj_scale_factor, self.lj_a,
+ self.lj_b)
+ nb14_cf_e = nb14_cf_energy(self.nb14_numbers, self.atom_numbers, uint_crd,
+ self.atom_lj_type, self.charge, uint_dr_to_dr_cof,
+ self.nb14_atom_a, self.nb14_atom_b, self.cf_scale_factor)
+ nb14_lj_energy_sum = nb14_lj_e.sum(keepdims=True)
+ nb14_cf_energy_sum = nb14_cf_e.sum(keepdims=True)
+
+ lj_e = lj_energy(self.atom_numbers, self.cutoff_square, uint_crd,
+ self.atom_lj_type, self.charge, uint_dr_to_dr_cof,
+ self.nl_atom_numbers, self.nl_atom_serial, self.lj_a,
+ self.lj_b)
+ lj_energy_sum = lj_e.sum(keepdims=True)
+ reciprocal_e, self_e, direct_e, correction_e = pme_energy(self.atom_numbers,
+ self.beta, self.fftx,
+ self.ffty, self.fftz,
+ self.pme_bc,
+ uint_crd, self.charge,
+ self.nl_atom_numbers,
+ self.nl_atom_serial,
+ uint_dr_to_dr_cof,
+ self.excluded_matrix)
+
+ ee_ene = reciprocal_e + self_e + direct_e + correction_e
+ total_energy = bond_energy_sum + angle_energy_sum + dihedral_energy_sum + \
+ nb14_lj_energy_sum + nb14_cf_energy_sum + lj_energy_sum + ee_ene
+ return bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, nb14_cf_energy_sum, \
+ lj_energy_sum, ee_ene, total_energy
+
+ def simulation_temperature(self):
+ '''caculate temperature'''
+ res_ek_energy = md_temperature(self.residue_numbers, self.atom_numbers, self.res_start,
+ self.res_end, self.velocity, self.mass)
+ temperature = res_ek_energy.sum()
+ return temperature
+
+ def simulation_md_iteration_leap_frog_liujian(self, inverse_mass, sqrt_mass_inverse, crd, frc):
+ '''simulation leap frog iteration liujian'''
+ return md_iteration_leap_frog_liujian(self.atom_numbers, self.half_dt,
+ self.dt, self.exp_gamma, inverse_mass,
+ sqrt_mass_inverse, self.velocity,
+ crd, frc, self.acc)
+
+ def main_print(self, *args):
+ """compute the temperature"""
+ steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
+ nb14_lj_energy_sum, nb14_cf_energy_sum, lj_energy_sum, ee_ene = list(args)
+ if steps == 0:
+ print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
+ "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_")
+
+ temperature = temperature.asnumpy()
+ total_potential_energy = total_potential_energy.asnumpy()
+ print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)), end=" ")
+ if self.bond.bond_numbers > 0:
+ sigma_of_bond_ene = sigma_of_bond_ene.asnumpy()
+ print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ")
+ if self.angle.angle_numbers > 0:
+ sigma_of_angle_ene = sigma_of_angle_ene.asnumpy()
+ print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ")
+ if self.dihedral.dihedral_numbers > 0:
+ sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy()
+ print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ")
+ if self.nb14.nb14_numbers > 0:
+ nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy()
+ nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy()
+ print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ")
+ lj_energy_sum = lj_energy_sum.asnumpy()
+ ee_ene = ee_ene.asnumpy()
+ print("{:>7.3f}".format(float(lj_energy_sum)), end=" ")
+ print("{:>12.3f}".format(float(ee_ene)))
+ if self.file is not None:
+ self.file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}"
+ " {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy),
+ float(sigma_of_bond_ene), float(sigma_of_angle_ene),
+ float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum),
+ float(nb14_cf_energy_sum), float(lj_energy_sum), float(ee_ene)))
+ if self.datfile is not None:
+ self.datfile.write(self.crd.asnumpy())
+
+ def main_initial(self):
+ """main initial"""
+ if self.control.mdout:
+ self.file = open(self.control.mdout, 'w')
+ self.file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
+ "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n")
+ if self.control.mdcrd:
+ self.datfile = open(self.control.mdcrd, 'wb')
+
+ def main_destroy(self):
+ """main destroy"""
+ if self.file is not None:
+ self.file.close()
+ print("Save .out file successfully!")
+ if self.datfile is not None:
+ self.datfile.close()
+ print("Save .dat file successfully!")
+
+ def construct(self):
+ '''construct'''
+ self.last_crd = self.crd
+ res = self.neighbor_list_update(self.atom_numbers_in_grid_bucket,
+ self.bucket,
+ self.crd,
+ self.box_length,
+ self.grid_n,
+ self.grid_length_inverse,
+ self.atom_in_grid_serial,
+ self.old_crd,
+ self.crd_to_uint_crd_cof,
+ self.uint_crd,
+ self.pointer,
+ self.nl_atom_numbers,
+ self.nl_atom_serial,
+ self.uint_dr_to_dr_cof,
+ self.excluded_list_start,
+ self.excluded_list,
+ self.excluded_numbers,
+ self.need_refresh_flag,
+ self.refresh_count)
+ self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
+ self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
+ self.uint_dr_to_dr_cof = F.depend(self.uint_dr_to_dr_cof, res)
+ self.old_crd = F.depend(self.old_crd, res)
+ self.atom_numbers_in_grid_bucket = F.depend(self.atom_numbers_in_grid_bucket, res)
+ self.bucket = F.depend(self.bucket, res)
+ self.atom_in_grid_serial = F.depend(self.atom_in_grid_serial, res)
+ self.pointer = F.depend(self.pointer, res)
+ uint_crd = self.simulation_beforce_caculate_force()
+ force = self.simulation_caculate_force(uint_crd, self.uint_dr_to_dr_cof, self.nl_atom_numbers,
+ self.nl_atom_serial)
+
+ bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, nb14_cf_energy_sum, \
+ lj_energy_sum, ee_ene, total_energy = self.simulation_caculate_energy(uint_crd, self.uint_dr_to_dr_cof)
+
+ temperature = self.simulation_temperature()
+ self.velocity = F.depend(self.velocity, temperature)
+ self.velocity, self.crd, _ = self.simulation_md_iteration_leap_frog_liujian(self.mass_inverse,
+ self.sqrt_mass, self.crd, force)
+ return temperature, total_energy, bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, \
+ nb14_cf_energy_sum, lj_energy_sum, ee_ene, res
diff --git a/MindSPONGE/mindsponge/md/functions/lj_force_pme_direct_force.py b/MindSPONGE/mindsponge/md/functions/lj_force_pme_direct_force.py
index 260f3cb042c0264258b9c4ee2d3a778ab17cb807..ef7acd4d55775323e433c74d1e91b3522affaaa6 100644
--- a/MindSPONGE/mindsponge/md/functions/lj_force_pme_direct_force.py
+++ b/MindSPONGE/mindsponge/md/functions/lj_force_pme_direct_force.py
@@ -21,28 +21,28 @@ from .common import get_neighbour_index, get_periodic_displacement, get_zero_ten
TWO_DIVIDED_BY_SQRT_PI = 1.1283791670218446
MAX_NUMBER_OF_NEIGHBOR = 800
-def lj_force_pme_direct_force(atom_numbers, cutoff, pme_beta, uint_crd, LJtype, charge,
- scalar, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
+def lj_force_pme_direct_force(atom_numbers, cutoff, pme_beta, uint_crd, lj_type, charge,
+ scalar, nl_numbers, nl_serial, d_lj_a, d_lj_b):
"""
Calculate the Lennard-Jones force and PME direct force together.
The calculation formula of Lennard-Jones part is the same as operator
- LJForce(), and the PME direct part is within PME method.
+ ljForce(), and the PME direct part is within PME method.
Agrs:
atom_numbers(int): the number of atoms, N.
cutoff(float): the square value of cutoff.
pme_beta(float): PME beta parameter, same as operator PMEReciprocalForce().
uint_crd (Tensor, uint32): [N, 3], the unsigned int coordinate value of each atom.
- LJtype (Tensor, int32): [N,], the Lennard-Jones type of each atom.
+ lj_type (Tensor, int32): [N,], the Lennard-Jones type of each atom.
charge (Tensor, float32): [N,], the charge carried by each atom.
scaler (Tensor, float32): [3,], the scale factor between real
space coordinate and its unsigned int value.
nl_numbers (Tensor, int32): [N,], the each atom.
nl_serial (Tensor, int32): [N, 800], the neighbor list of each atom, the max number is 800.
- d_LJ_A (Tensor, float32): [Q,], the Lennard-Jones A coefficient of each kind of atom pair.
+ d_lj_a (Tensor, float32): [Q,], the Lennard-Jones A coefficient of each kind of atom pair.
Q is the number of atom pair.
- d_LJ_B (Tensor, float32): [Q,], the Lennard-Jones B coefficient of each kind of atom pair.
+ d_lj_b (Tensor, float32): [Q,], the Lennard-Jones B coefficient of each kind of atom pair.
Q is the number of atom pair.
Outputs:
@@ -51,8 +51,8 @@ def lj_force_pme_direct_force(atom_numbers, cutoff, pme_beta, uint_crd, LJtype,
Supported Platforms:
``GPU``
"""
- N = uint_crd.shape[0]
- frc = get_zero_tensor((N, 3), np.float32)
+ n = uint_crd.shape[0]
+ frc = get_zero_tensor((n, 3), np.float32)
r1 = np.tile(np.expand_dims(uint_crd, 1), (1, MAX_NUMBER_OF_NEIGHBOR, 1))
r2 = uint_crd[nl_serial]
@@ -67,15 +67,15 @@ def lj_force_pme_direct_force(atom_numbers, cutoff, pme_beta, uint_crd, LJtype,
dr_8 = dr_4 * dr_4
dr_6 = dr_4 * dr_2
- r1_LJ_type = np.expand_dims(LJtype, -1)
- r2_LJ_type = LJtype[nl_serial]
- x = r2_LJ_type + r1_LJ_type
- y = np.absolute(r2_LJ_type - r1_LJ_type)
- r2_LJ_type = (x + y) // 2
+ r1_lj_type = np.expand_dims(lj_type, -1)
+ r2_lj_type = lj_type[nl_serial]
+ x = r2_lj_type + r1_lj_type
+ y = np.absolute(r2_lj_type - r1_lj_type)
+ r2_lj_type = (x + y) // 2
x = (x - y) // 2
- atom_pair_LJ_type = (r2_LJ_type * (r2_LJ_type + 1) // 2) + x
+ atom_pair_lj_type = (r2_lj_type * (r2_lj_type + 1) // 2) + x
- frc_abs = (-d_LJ_A[atom_pair_LJ_type] * dr_6 + d_LJ_B[atom_pair_LJ_type]) * dr_8
+ frc_abs = (-d_lj_a[atom_pair_lj_type] * dr_6 + d_lj_b[atom_pair_lj_type]) * dr_8
beta_dr = pme_beta * dr_abs
frc_cf_abs = beta_dr * TWO_DIVIDED_BY_SQRT_PI * np.exp(-beta_dr * beta_dr) + ops.Erfc()(beta_dr)
frc_cf_abs *= dr_2 * dr_1
@@ -89,7 +89,7 @@ def lj_force_pme_direct_force(atom_numbers, cutoff, pme_beta, uint_crd, LJtype,
# apply cutoff mask
frc_lin = np.where(np.expand_dims(mask, -1), frc_lin, 0)
frc_record = np.sum(frc_lin, -2)
-
+ nl_serial = np.where(nl_atom_mask >= np.expand_dims(nl_numbers, -1), -1, nl_serial)
frc = ops.tensor_scatter_add(frc, np.expand_dims(nl_serial, -1), -frc_lin)
frc += frc_record
return frc
diff --git a/MindSPONGE/mindsponge/md/functions/pme_common.py b/MindSPONGE/mindsponge/md/functions/pme_common.py
index d5a693ccfaf881850bcd8362712db717117502ca..4681e65f6deeefa74ccb0a5f720b02313050b6da 100644
--- a/MindSPONGE/mindsponge/md/functions/pme_common.py
+++ b/MindSPONGE/mindsponge/md/functions/pme_common.py
@@ -21,29 +21,32 @@ from mindspore.ops import constexpr
from .common import get_neighbour_index, get_periodic_displacement
PERIODIC_FACTOR_INVERSE = 2.32830643e-10
-PME_Ma = mnp.array([1.0 / 6.0, -0.5, 0.5, -1.0 / 6.0])
-PME_Mb = mnp.array([0, 0.5, -1, 0.5])
-PME_Mc = mnp.array([0, 0.5, 0, -0.5])
-PME_Md = mnp.array([0, 1.0 / 6.0, 4.0 / 6.0, 1.0 / 6.0])
+pme_ma = mnp.array([1.0 / 6.0, -0.5, 0.5, -1.0 / 6.0])
+pme_mb = mnp.array([0, 0.5, -1, 0.5])
+pme_mc = mnp.array([0, 0.5, 0, -0.5])
+pme_md = mnp.array([0, 1.0 / 6.0, 4.0 / 6.0, 1.0 / 6.0])
fft3d = ops.FFT3D()
ifft3d = ops.IFFT3D()
+real = ops.Real()
+conj = ops.Conj()
@constexpr
-def to_tensor(args, dtype=mnp.float32):
+def to_tensor(args):
return mnp.array(args)
-def Scale_List(element_numbers, tensor, scaler):
+def scale_list(element_numbers, tensor, scaler):
"""Scale values in `tensor`."""
if tensor.ndim > 0 and len(tensor) > element_numbers:
tensor = tensor[:element_numbers]
return tensor * scaler
-def PME_Atom_Near(uint_crd, PME_atom_near, PME_Nin, periodic_factor_inverse_x,
- periodic_factor_inverse_y, periodic_factor_inverse_z, atom_numbers,
- fftx, ffty, fftz, PME_kxyz, PME_uxyz):
+# pylint: disable=too-many-function-args
+def pme_a_near(uint_crd, pme_atom_near, pme_nin, periodic_factor_inverse_x,
+ periodic_factor_inverse_y, periodic_factor_inverse_z, atom_numbers,
+ fftx, ffty, fftz, pme_kxyz, pme_uxyz):
'''pme atom near'''
periodic_factor_inverse_xyz = to_tensor(
(periodic_factor_inverse_x, periodic_factor_inverse_y, periodic_factor_inverse_z))
@@ -51,46 +54,47 @@ def PME_Atom_Near(uint_crd, PME_atom_near, PME_Nin, periodic_factor_inverse_x,
tempf = uint_crd.astype('float32') * periodic_factor_inverse_xyz
tempu = tempf.astype('int32')
tempu = ops.depend(tempu, tempu)
- PME_frxyz = tempf - tempu
+ pme_frxyz = tempf - tempu
- cond = mnp.not_equal(PME_uxyz.astype(mnp.int32), tempu).any(1, True)
- PME_uxyz = mnp.where(cond, tempu, PME_uxyz)
+ cond = mnp.not_equal(pme_uxyz.astype(mnp.int32), tempu).any(1, True)
+ pme_uxyz = mnp.where(cond, tempu, pme_uxyz)
tempu = tempu.reshape(atom_numbers, 1, 3)
- kxyz = tempu - PME_kxyz.astype(mnp.int32)
+ kxyz = tempu - pme_kxyz.astype(mnp.int32)
kxyz_plus = kxyz + mnp.array([fftx, ffty, fftz])
kxyz = ops.select(kxyz < 0, kxyz_plus, kxyz)
- kxyz = kxyz * to_tensor((PME_Nin, fftz, 1), mnp.int32).reshape(1, 1, 3)
+ kxyz = kxyz * to_tensor((pme_nin, fftz, 1)).reshape(1, 1, 3)
temp_near = mnp.sum(kxyz.astype(mnp.float32), -1).astype(mnp.int32)
- PME_atom_near = mnp.where(cond, temp_near, PME_atom_near)
+ pme_atom_near = mnp.where(cond, temp_near, pme_atom_near)
- return PME_frxyz, PME_uxyz, PME_atom_near
+ return pme_frxyz, pme_uxyz, pme_atom_near
-def PME_Q_Spread(PME_atom_near, charge, PME_frxyz, PME_Q, PME_kxyz, atom_numbers):
+def pme_q_spread(pme_atom_near, charge, pme_frxyz, pme_q, pme_kxyz, atom_numbers):
'''pme q spread'''
- PME_kxyz = PME_kxyz.astype(mnp.int32)
- pme_ma = PME_Ma[PME_kxyz]
- pme_mb = PME_Mb[PME_kxyz]
- pme_mc = PME_Mc[PME_kxyz]
- pme_md = PME_Md[PME_kxyz]
+ pme_kxyz = pme_kxyz.astype(mnp.int32)
+ pme_ma_new = pme_ma[pme_kxyz]
+ pme_mb_new = pme_mb[pme_kxyz]
+ pme_mc_new = pme_mc[pme_kxyz]
+ pme_md_new = pme_md[pme_kxyz]
- tempf = PME_frxyz.reshape(atom_numbers, 1, 3) # (N, 1, 3)
+ tempf = pme_frxyz.reshape(atom_numbers, 1, 3) # (N, 1, 3)
tempf2 = tempf * tempf # (N, 1, 3)
temp_charge = charge.reshape(atom_numbers, 1) # (N, 1)
- tempf = pme_ma * tempf * tempf2 + pme_mb * tempf2 + pme_mc * tempf + pme_md # (N, 64, 3)
+ tempf = pme_ma_new * tempf * tempf2 + pme_mb_new * tempf2 + pme_mc_new * tempf + pme_md_new # (N, 64, 3)
- tempQ = temp_charge * tempf[..., 0] * tempf[..., 1] * tempf[..., 2] # (N, 64)
- index = PME_atom_near.ravel() # (N * 64,)
- tempQ = tempQ.ravel() # (N * 64,)
- PME_Q = ops.tensor_scatter_add(PME_Q, mnp.expand_dims(index, -1), tempQ)
+ tempq = temp_charge * tempf[..., 0] * tempf[..., 1] * tempf[..., 2] # (N, 64)
+ index = pme_atom_near.ravel() # (N * 64,)
+ tempq = tempq.ravel() # (N * 64,)
+ pme_q = ops.tensor_scatter_add(pme_q, mnp.expand_dims(index, -1), tempq)
- return PME_Q
+ return pme_q
-def PME_Direct_Energy(atom_numbers, nl_numbers, nl_serial, uint_crd, boxlength, charge, beta, cutoff_square):
+# pylint: disable=too-many-arguments
+def pme_direct_energy(atom_numbers, nl_numbers, nl_serial, uint_crd, boxlength, charge, beta, cutoff_square):
'''pme direct energy'''
r2 = uint_crd[nl_serial]
@@ -123,15 +127,15 @@ def get_pme_kxyz():
return pme_kxyz
-def PME_Energy_Reciprocal(real, imag, BC):
- return mnp.sum((real * real + imag * imag) * BC)
+def pme_energy_reciprocal(pme_fq, bc):
+ return mnp.sum(real(conj(pme_fq) * pme_fq) * bc)
-def PME_Energy_Product(tensor1, tensor2):
+def pme_energy_product(tensor1, tensor2):
return mnp.sum(tensor1 * tensor2)
-def PME_Excluded_Energy_Correction(atom_numbers, uint_crd, scaler, charge, pme_beta, sqrt_pi, excluded_matrix):
+def pme_excluded_energy_correction(uint_crd, scaler, charge, pme_beta, excluded_matrix):
'''pme excluded energy correction'''
mask = (excluded_matrix > -1)
# (N, 3)[N, M]-> (N, M, 3)
diff --git a/MindSPONGE/mindsponge/md/functions/pme_energy.py b/MindSPONGE/mindsponge/md/functions/pme_energy.py
index 2320e8cceddb1b49aa01a3a90d614ec387509a9f..65121d7506f17ba82828da74c91b34893f5088a9 100644
--- a/MindSPONGE/mindsponge/md/functions/pme_energy.py
+++ b/MindSPONGE/mindsponge/md/functions/pme_energy.py
@@ -15,12 +15,11 @@
'''pme energy'''
import mindspore.numpy as mnp
from .common import PI, get_zero_tensor, get_full_tensor
-from .pme_common import Scale_List, PME_Atom_Near, PME_Q_Spread, PME_Direct_Energy, PME_Energy_Reciprocal, \
- PME_Excluded_Energy_Correction, PME_Energy_Product, get_pme_kxyz, PERIODIC_FACTOR_INVERSE, \
+from .pme_common import scale_list, pme_a_near, pme_q_spread, pme_direct_energy, pme_energy_reciprocal, \
+ pme_excluded_energy_correction, pme_energy_product, get_pme_kxyz, PERIODIC_FACTOR_INVERSE, \
fft3d
cutoff = 10.0
-
def pme_energy(atom_numbers, beta, fftx, ffty, fftz, pme_bc, uint_crd, charge,
nl_numbers, nl_serial, scaler, excluded_matrix):
"""
@@ -55,35 +54,34 @@ def pme_energy(atom_numbers, beta, fftx, ffty, fftz, pme_bc, uint_crd, charge,
Supported Platforms:
``GPU``
"""
- PME_Nin = ffty * fftz
- PME_Nall = fftx * ffty * fftz
+ pme_nin = ffty * fftz
+ pme_nall = fftx * ffty * fftz
- PME_kxyz = get_pme_kxyz() # (64, 3)
+ pme_kxyz = get_pme_kxyz() # (64, 3)
- PME_uxyz = get_full_tensor((atom_numbers, 3), 2 ** 30, mnp.uint32)
- PME_atom_near = get_zero_tensor((atom_numbers, 64), mnp.int32)
- PME_frxyz, PME_uxyz, PME_atom_near = PME_Atom_Near(uint_crd, PME_atom_near, PME_Nin,
- PERIODIC_FACTOR_INVERSE * fftx,
- PERIODIC_FACTOR_INVERSE * ffty,
- PERIODIC_FACTOR_INVERSE * fftz, atom_numbers,
- fftx, ffty, fftz, PME_kxyz, PME_uxyz)
+ pme_uxyz = get_full_tensor((atom_numbers, 3), 2 ** 30, mnp.uint32)
+ pme_atom_near = get_zero_tensor((atom_numbers, 64), mnp.int32)
+ pme_frxyz, pme_uxyz, pme_atom_near = pme_a_near(uint_crd, pme_atom_near, pme_nin,
+ PERIODIC_FACTOR_INVERSE * fftx,
+ PERIODIC_FACTOR_INVERSE * ffty,
+ PERIODIC_FACTOR_INVERSE * fftz, atom_numbers,
+ fftx, ffty, fftz, pme_kxyz, pme_uxyz)
- PME_Q = get_full_tensor(PME_Nall, 0, mnp.float32)
- PME_Q = PME_Q_Spread(PME_atom_near, charge, PME_frxyz, PME_Q, PME_kxyz, atom_numbers)
+ pme_q = get_full_tensor(pme_nall, 0, mnp.float32)
+ pme_q = pme_q_spread(pme_atom_near, charge, pme_frxyz, pme_q, pme_kxyz, atom_numbers)
- PME_Q = PME_Q.reshape(fftx, ffty, fftz).astype('float32')
- real, imag = fft3d(PME_Q)
+ pme_q = pme_q.reshape(fftx, ffty, fftz).astype('float32')
+ pme_fq = fft3d(pme_q)
- reciprocal_ene = PME_Energy_Reciprocal(real.ravel(), imag.ravel(), pme_bc)
+ reciprocal_ene = pme_energy_reciprocal(pme_fq, pme_bc.reshape((fftx, ffty, fftz // 2 + 1)))
- self_ene = PME_Energy_Product(charge, charge)
- self_ene = Scale_List(1, self_ene, -beta / mnp.sqrt(PI))
+ self_ene = pme_energy_product(charge, charge)
+ self_ene = scale_list(1, self_ene, -beta / mnp.sqrt(PI))
- direct_ene = PME_Direct_Energy(atom_numbers, nl_numbers, nl_serial, uint_crd, scaler, charge, beta,
+ direct_ene = pme_direct_energy(atom_numbers, nl_numbers, nl_serial, uint_crd, scaler, charge, beta,
cutoff * cutoff)
- correction_ene = PME_Excluded_Energy_Correction(atom_numbers, uint_crd, scaler, charge, beta, mnp.sqrt(PI),
- excluded_matrix)
+ correction_ene = pme_excluded_energy_correction(uint_crd, scaler, charge, beta, excluded_matrix)
return mnp.atleast_1d(reciprocal_ene), mnp.atleast_1d(self_ene), \
mnp.atleast_1d(direct_ene), mnp.atleast_1d(correction_ene)
diff --git a/MindSPONGE/mindsponge/md/functions/pme_reciprocal_force.py b/MindSPONGE/mindsponge/md/functions/pme_reciprocal_force.py
index 3e7d6cce050dd3a572ef8259dc52081978a7086f..4191c78ba2aeb55fc4bb1cbeb4c2745251984e6c 100644
--- a/MindSPONGE/mindsponge/md/functions/pme_reciprocal_force.py
+++ b/MindSPONGE/mindsponge/md/functions/pme_reciprocal_force.py
@@ -14,44 +14,44 @@
# ============================================================================
'''pme reciprocal force'''
from mindspore import numpy as np
-
+from mindspore import dtype as mstype
from .common import get_full_tensor
-from .pme_common import get_pme_kxyz, PME_Atom_Near, PME_Q_Spread, to_tensor, \
-PERIODIC_FACTOR_INVERSE, PME_Ma, PME_Mb, PME_Mc, PME_Md, fft3d, ifft3d
+from .pme_common import get_pme_kxyz, pme_a_near, pme_q_spread, to_tensor, \
+PERIODIC_FACTOR_INVERSE, pme_ma, pme_mb, pme_mc, pme_md, fft3d, ifft3d
MAXINT = 1073741824
-PME_dMa = np.array([0.5, -1.5, 1.5, -0.5], np.float32)
-PME_dMb = np.array([0, 1, -2, 1], np.float32)
-PME_dMc = np.array([0, 0.5, 0, -0.5], np.float32)
+pme_dma = np.array([0.5, -1.5, 1.5, -0.5], np.float32)
+pme_dmb = np.array([0, 1, -2, 1], np.float32)
+pme_dmc = np.array([0, 0.5, 0, -0.5], np.float32)
-def pme_final(pme_atom_near, charge, pme_Q, pme_frxyz, pme_kxyz, pme_inverse_box_vector, atom_numbers):
+def pme_final(pme_atom_near, charge, pme_q, pme_frxyz, pme_kxyz, pme_inverse_box_vector):
'''pme final'''
- dQf = -pme_Q[pme_atom_near] * np.expand_dims(charge, -1) # N * 64
+ dqf = -pme_q[pme_atom_near] * np.expand_dims(charge, -1) # N * 64
fxyz = np.expand_dims(pme_frxyz, -2)
fxyz_2 = fxyz ** 2
# N * 64 * 3
pme_kxyz = pme_kxyz.astype(np.int32)
- xyz = (PME_Ma[pme_kxyz] * fxyz * fxyz_2 + PME_Mb[pme_kxyz] * fxyz_2 +
- PME_Mc[pme_kxyz] * fxyz + PME_Md[pme_kxyz])
- dxyz = PME_dMa[pme_kxyz] * fxyz_2 + PME_dMb[pme_kxyz] * fxyz + PME_dMc[pme_kxyz]
- Qxyz = dxyz * pme_inverse_box_vector
+ xyz = (pme_ma[pme_kxyz] * fxyz * fxyz_2 + pme_mb[pme_kxyz] * fxyz_2 +
+ pme_mc[pme_kxyz] * fxyz + pme_md[pme_kxyz])
+ dxyz = pme_dma[pme_kxyz] * fxyz_2 + pme_dmb[pme_kxyz] * fxyz + pme_dmc[pme_kxyz]
+ qxyz = dxyz * pme_inverse_box_vector
x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
- Qx, Qy, Qz = Qxyz[..., 0], Qxyz[..., 1], Qxyz[..., 2]
+ qx, qy, qz = qxyz[..., 0], qxyz[..., 1], qxyz[..., 2]
- Qx *= y * z * dQf
- Qy *= x * z * dQf
- Qz *= x * y * dQf
+ qx *= y * z * dqf
+ qy *= x * z * dqf
+ qz *= x * y * dqf
- force = np.stack((Qx, Qy, Qz), axis=-1)
+ force = np.stack((qx, qy, qz), axis=-1)
return np.sum(force, axis=1)
-def pme_reciprocal_force(atom_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1,
+def pme_reciprocal_force(atom_numbers, fftx, ffty, fftz, box_length_0, box_length_1,
box_length_2, pme_bc, uint_crd, charge):
"""
Calculate the reciprocal part of long-range Coulumb force using
@@ -63,8 +63,6 @@ def pme_reciprocal_force(atom_numbers, beta, fftx, ffty, fftz, box_length_0, box
Args:
atom_numbers (int): the number of atoms, N.
- beta (float): the PME beta parameter, determined by the
- non-bond cutoff value and simulation precision tolerance.
fftx (int): the number of points for Fourier transform in dimension X.
ffty (int): the number of points for Fourier transform in dimension Y.
fftz (int): the number of points for Fourier transform in dimension Z.
@@ -80,27 +78,25 @@ def pme_reciprocal_force(atom_numbers, beta, fftx, ffty, fftz, box_length_0, box
Supported Platforms:
``GPU``
"""
- pme_Nall = fftx * ffty * fftz
- pme_Nin = ffty * fftz
+ pme_nall = fftx * ffty * fftz
+ pme_nin = ffty * fftz
pme_atom_near = get_full_tensor((atom_numbers, 64), 0, np.int32)
pme_uxyz = get_full_tensor((atom_numbers, 3), MAXINT, np.uint32)
pme_kxyz = get_pme_kxyz()
pme_inverse_box_vector = to_tensor((fftx / box_length_0, ffty / box_length_1, fftz / box_length_2))
- pme_frxyz, pme_uxyz, pme_atom_near = PME_Atom_Near(
- uint_crd, pme_atom_near, pme_Nin, PERIODIC_FACTOR_INVERSE * fftx, PERIODIC_FACTOR_INVERSE * ffty,
+ pme_frxyz, pme_uxyz, pme_atom_near = pme_a_near(
+ uint_crd, pme_atom_near, pme_nin, PERIODIC_FACTOR_INVERSE * fftx, PERIODIC_FACTOR_INVERSE * ffty,
PERIODIC_FACTOR_INVERSE * fftz, atom_numbers, fftx, ffty, fftz, pme_kxyz, pme_uxyz)
- pme_Q = get_full_tensor(pme_Nall, 0, np.float32)
- pme_Q = PME_Q_Spread(pme_atom_near, charge, pme_frxyz, pme_Q, pme_kxyz, atom_numbers)
+ pme_q = get_full_tensor(pme_nall, 0, np.float32)
+ pme_q = pme_q_spread(pme_atom_near, charge, pme_frxyz, pme_q, pme_kxyz, atom_numbers)
- pme_Q = pme_Q.reshape(fftx, ffty, fftz)
+ pme_q = pme_q.reshape(fftx, ffty, fftz)
pme_bc = pme_bc.reshape(fftx, ffty, fftz // 2 + 1)
- pme_FQ_real, pme_FQ_imag = fft3d(pme_Q)
- pme_FQ_real *= pme_bc
- pme_FQ_imag *= pme_bc
- pme_Q = ifft3d(pme_FQ_real, pme_FQ_imag)
- pme_Q = pme_Q.ravel()
-
- return pme_final(
- pme_atom_near, charge, pme_Q, pme_frxyz, pme_kxyz, pme_inverse_box_vector, atom_numbers)
+ pme_fq = fft3d(pme_q)
+ pme_fq *= pme_bc.astype(mstype.complex64)
+ pme_q = ifft3d(pme_fq)
+ pme_q = pme_q.ravel()
+
+ return pme_final(pme_atom_near, charge, pme_q, pme_frxyz, pme_kxyz, pme_inverse_box_vector)