diff --git a/MindSPONGE/examples/claisen_rearrangement/src/main.py b/MindSPONGE/examples/claisen_rearrangement/src/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..13d77717b421f94f59b921f010ee82543c2712c4
--- /dev/null
+++ b/MindSPONGE/examples/claisen_rearrangement/src/main.py
@@ -0,0 +1,117 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''main'''
+import argparse
+import time
+import sys
+
+from mindspore import context
+from mindspore import Tensor
+from mindspore import load_checkpoint
+
+USE_SYS_PATH = True
+
+if USE_SYS_PATH:
+    sys.path.append('..')
+    sys.path.append('../../../mindsponge/md')
+    from simulation_cybertron import SimulationCybertron
+    from cybertron.mdnn import Mdnn, TransCrdToCV
+    from cybertron.models import MolCT
+    from cybertron.readouts import AtomwiseReadout
+    from cybertron.cybertron import Cybertron
+
+parser = argparse.ArgumentParser(description='SPONGE Controller')
+parser.add_argument('--i', type=str, default='md.in', help='Input file')
+parser.add_argument('--amber_parm', type=str, default='cba.prmtop', help='Parameter file in AMBER format')
+parser.add_argument('--c', type=str, default='cba_its_mw0_trans.rst7', help='Initial coordinates file')
+parser.add_argument('--r', type=str, default=None, help='Restart file')
+parser.add_argument('--x', type=str, default="mdcrd", help='Coordinate trajectory file')
+parser.add_argument('--o', type=str, default="mdout.txt", help='Output file')
+parser.add_argument('--box', type=str, default="mdbox.txt", help='Box trajectory file')
+parser.add_argument('--device_id', type=int, default=1, help='GPU device id')
+parser.add_argument('--u', type=int, default=0, help='Set to 1 to use mdnn to update the atom charge.')
+parser.add_argument('--checkpoint', type=str, default="", help='Checkpoint file')
+parser.add_argument('--datfile', type=str, default="crd_record.dat", help='Store the evolution path in a dat file.')
+parser.add_argument('--initial_coordinates_file', default=None, type=str, help='Initial rst7 pos file.')
+parser.add_argument('--meta', type=int, default=0, help='Set to 1 if MetaDynamics is used.')
+parser.add_argument('--with_box', type=int, default=1, help='Set to 1 if a periodic map is needed.')
+parser.add_argument('--np_iter', type=int, default=0, help='Set to 1 if you want to use msnp.')
+
+if __name__ == '__main__':
+
+    args_opt = parser.parse_args()
+    args_opt.initial_coordinates_file = args_opt.c
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False)
+    atom_types = Tensor([6, 1, 6, 1, 6, 1, 1, 6, 1, 6, 1, 6, 1, 1, 8])
+    num_atom = atom_types.size
+    mod = MolCT(
+        min_rbf_dis=0.1,
+        max_rbf_dis=10,
+        num_rbf=128,
+        rbf_sigma=0.2,
+        n_interactions=3,
+        dim_feature=128,
+        n_heads=8,
+        max_cycles=1,
+        use_time_embedding=True,
+        fixed_cycles=True,
+        self_dis=0.1,
+        unit_length='A',
+        use_feed_forward=False,
+    )
+    scales = 3.0
+
+    readout = AtomwiseReadout(n_in=mod.dim_feature, n_interactions=mod.n_interactions, activation=mod.activation,
+                              n_out=1, mol_scale=scales, unit_energy='kcal/mol')
+    net = Cybertron(mod, atom_types=atom_types, full_connect=True, readout=readout, unit_dis='A',
+                    unit_energy='kcal/mol')
+
+    param_file = 'cba_kcal_mol_A_MolCT-best.ckpt'
+    load_checkpoint(param_file, net=net)
+
+    simulation = SimulationCybertron(args_opt, network=net)
+    if args_opt.u and args_opt.checkpoint:
+        net = Mdnn()
+        load_checkpoint(args_opt.checkpoint, net=net)
+        transcrd = TransCrdToCV(simulation)
+
+    start = time.time()
+    compile_time = 0
+    save_path = args_opt.o
+    simulation.main_initial()
+    for steps in range(simulation.md_info.step_limit):
+        print_step = steps % simulation.ntwx
+        if steps == simulation.md_info.step_limit - 1:
+            print_step = 0
+        temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
+        nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene, _ = simulation(Tensor(steps), Tensor(print_step))
+
+        if steps == 0:
+            compile_time = time.time()
+        if steps % simulation.ntwx == 0 or steps == simulation.md_info.step_limit - 1:
+            simulation.main_print(steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene,
+                                  Tensor(0), Tensor(0), nb14_cf_energy_sum, LJ_energy_sum, ee_ene)
+
+        if args_opt.u and args_opt.checkpoint and steps % (4 * simulation.ntwx) == 0:
+            print("Update charge!")
+            inputs = transcrd(Tensor(simulation.crd), Tensor(simulation.last_crd))
+            t_charge = net(inputs)
+            simulation.charge = transcrd.updatecharge(t_charge)
+
+    end = time.time()
+    print("Main time(s):", end - start)
+    print("Main time(s) without compilation:", end - compile_time)
+    simulation.main_destroy()
diff --git a/MindSPONGE/examples/claisen_rearrangement/src/simulation_cybertron.py b/MindSPONGE/examples/claisen_rearrangement/src/simulation_cybertron.py
new file mode 100644
index 0000000000000000000000000000000000000000..2467fd800b10ce7c9e474053808ad8dab1f47fc3
--- /dev/null
+++ b/MindSPONGE/examples/claisen_rearrangement/src/simulation_cybertron.py
@@ -0,0 +1,462 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
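A minimal sketch of evaluating the Cybertron potential assembled in main.py outside the MD loop, assuming the cybertron package from this PR is on sys.path and the checkpoint named above is available. The call pattern net(positions, atom_types, None, None) mirrors how SimulationCybertron invokes the network; passing None for the optional inputs is assumed to fall back to the atom types fixed at construction.

import numpy as np
from mindspore import Tensor, load_checkpoint
from mindspore.ops import composite as C
from cybertron.models import MolCT
from cybertron.readouts import AtomwiseReadout
from cybertron.cybertron import Cybertron

atom_types = Tensor([6, 1, 6, 1, 6, 1, 1, 6, 1, 6, 1, 6, 1, 1, 8])
mod = MolCT(min_rbf_dis=0.1, max_rbf_dis=10, num_rbf=128, n_interactions=3,
            dim_feature=128, n_heads=8, unit_length='A')
readout = AtomwiseReadout(n_in=mod.dim_feature, n_interactions=mod.n_interactions,
                          activation=mod.activation, n_out=1, unit_energy='kcal/mol')
net = Cybertron(mod, atom_types=atom_types, full_connect=True, readout=readout,
                unit_dis='A', unit_energy='kcal/mol')
load_checkpoint('cba_kcal_mol_A_MolCT-best.ckpt', net=net)

crd = Tensor(np.random.rand(1, 15, 3).astype(np.float32))    # [B, A, 3] in Angstrom
energy = net(crd, None, None, None)                          # potential energy, kcal/mol
forces = -1 * C.GradOperation()(net)(crd, None, None, None)  # forces = -dE/dR, as in the simulation cell
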
+# ============================================================================ +'''Simulation''' +import sys +import numpy as np + +import mindspore.numpy as msnp +import mindspore.common.dtype as mstype +from mindspore import ops +from mindspore import Tensor +from mindspore import nn +from mindspore.common.parameter import Parameter +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.ops import constexpr +from mindspore.ops import composite as C + +from potential.angle import Angle +from potential.bond import Bond +from potential.dihedral import Dihedral +from potential.lennard_jones import LennardJonesInformation +from potential.nb14 import NonBond14 +from potential.particle_mesh_ewald import ParticleMeshEwald +from control.langevin_liujian_md import LangevinLiujian +from space.md_information import MdInformation +from partition.neighbor_list import NeighborList +from cybertron.meta_dynamics import Bias +from cybertron.units import units + +sys.path.append('../../../mindsponge/md') + +standard_normal = ops.StandardNormal() +zeros = ops.Zeros() + +WALL_P = 9e08 +WALL_POTENTIAL = np.zeros(200, dtype=np.float32) +WALL_POTENTIAL[0] = WALL_P +WALL_POTENTIAL[1] = WALL_P +WALL_POTENTIAL[2] = WALL_P +WALL_POTENTIAL[-1] = WALL_P +WALL_POTENTIAL[-2] = WALL_P +WALL_POTENTIAL[-3] = WALL_P +SMIN = 0 +SMAX = 8 +DS = 0.04 +OMEGA = 50 +SIGMA = 0.005 +DDT = 0.001 +T = 300 +ALPHA = 0.5 +GAMMA = 6 +KAPPA = 4 +UPPER_BOUND_INDEX = 190 +LOWER_BOUND_INDEX = 10 +WALL_FACTOR = 0.1 + +@constexpr +def get_full_tensor(shape, fill_value, dtype=np.float32): + return msnp.full(shape, fill_value, dtype) + + +class Controller: + '''controller''' + + def __init__(self, args_opt): + self.input_file = args_opt.i + self.initial_coordinates_file = args_opt.c + self.amber_parm = args_opt.amber_parm + self.restrt = args_opt.r + self.mdcrd = args_opt.x + self.mdout = args_opt.o + self.mdbox = args_opt.box + self.meta = args_opt.meta + self.with_box = args_opt.with_box + self.np_iter = args_opt.np_iter + self.command_set = {} + self.md_task = None + self.commands_from_in_file() + + def commands_from_in_file(self): + '''command from in file''' + file = open(self.input_file, 'r') + context = file.readlines() + file.close() + self.md_task = context[0].strip() + for val in context: + if "=" in val: + assert len(val.strip().split("=")) == 2 + flag, value = val.strip().split("=") + value = value.replace(",", '') + flag = flag.replace(" ", "") + if flag not in self.command_set: + self.command_set[flag] = value + else: + print("ERROR COMMAND FILE") + + +class SimulationCybertron(nn.Cell): + '''simulation''' + + def __init__(self, args_opt, network=None): + super().__init__() + self.control = Controller(args_opt) + if self.control.meta: + self.meta = Tensor([1], mstype.int32) + else: + self.meta = Tensor([0], mstype.int32) + self.md_info = MdInformation(self.control) + self.bond = Bond(self.control) + self.angle = Angle(self.control) + self.dihedral = Dihedral(self.control) + self.nb14 = NonBond14(self.control, self.dihedral, self.md_info.atom_numbers) + self.nb_info = NeighborList(self.control, self.md_info.atom_numbers, self.md_info.box_length) + self.lj_info = LennardJonesInformation(self.control, self.md_info.nb.cutoff, self.md_info.sys.box_length) + self.liujian_info = LangevinLiujian(self.control, self.md_info.atom_numbers) + self.pme_method = ParticleMeshEwald(self.control, self.md_info) + self.bond_energy_sum = Tensor(0, mstype.int32) + self.angle_energy_sum = Tensor(0, 
mstype.int32)
+        self.dihedral_energy_sum = Tensor(0, mstype.int32)
+        self.nb14_lj_energy_sum = Tensor(0, mstype.int32)
+        self.nb14_cf_energy_sum = Tensor(0, mstype.int32)
+        self.lj_energy_sum = Tensor(0, mstype.int32)
+        self.ee_ene = Tensor(0, mstype.int32)
+        self.total_energy = Tensor(0, mstype.int32)
+        # Init scalar
+        self.ntwx = self.md_info.ntwx
+        self.atom_numbers = self.md_info.atom_numbers
+        self.residue_numbers = self.md_info.residue_numbers
+        self.bond_numbers = self.bond.bond_numbers
+        self.angle_numbers = self.angle.angle_numbers
+        self.dihedral_numbers = self.dihedral.dihedral_numbers
+        self.nb14_numbers = self.nb14.nb14_numbers
+        self.nxy = self.nb_info.nxy
+        self.grid_numbers = self.nb_info.grid_numbers
+        self.max_atom_in_grid_numbers = self.nb_info.max_atom_in_grid_numbers
+        self.max_neighbor_numbers = self.nb_info.max_neighbor_numbers
+        self.excluded_atom_numbers = self.nb_info.excluded_atom_numbers
+        self.refresh_count = Parameter(Tensor(self.nb_info.refresh_count, mstype.int32), requires_grad=False)
+        self.refresh_interval = self.nb_info.refresh_interval
+        self.skin = self.nb_info.skin
+        self.cutoff = self.nb_info.cutoff
+        self.cutoff_square = self.nb_info.cutoff_square
+        self.cutoff_with_skin = self.nb_info.cutoff_with_skin
+        self.half_cutoff_with_skin = self.nb_info.half_cutoff_with_skin
+        self.cutoff_with_skin_square = self.nb_info.cutoff_with_skin_square
+        self.half_skin_square = self.nb_info.half_skin_square
+        self.beta = self.pme_method.beta
+        self.fftx = self.pme_method.fftx
+        self.ffty = self.pme_method.ffty
+        self.fftz = self.pme_method.fftz
+        self.random_seed = self.liujian_info.random_seed
+        self.dt = self.liujian_info.dt
+        self.half_dt = self.liujian_info.half_dt
+        self.exp_gamma = self.liujian_info.exp_gamma
+
+        self.tmp_forces = Tensor(np.zeros((self.atom_numbers, 3)), dtype=mstype.float32)
+
+        self.bias_potential = Parameter(Tensor(WALL_POTENTIAL, mstype.float32), requires_grad=True)
+        self.grid_num = 200
+        self.wall_potential = WALL_P
+
+        self.meta_interval = 5
+
+        self.wall_factor = WALL_FACTOR
+        self.upper_bound_index = UPPER_BOUND_INDEX
+        self.lower_bound_index = LOWER_BOUND_INDEX
+        self.kappa = KAPPA
+        self.smin = Tensor(SMIN, mstype.float32)
+        self.smax = SMAX
+        self.t = T
+        self.alpha = ALPHA
+        self.gamma = GAMMA
+        self.ds = DS
+        self.ddt = DDT
+        self.sum = ops.ReduceSum()
+        self.omega = OMEGA
+        self.sigma = Tensor(SIGMA, mstype.float32)
+        self.exp = ops.Exp()
+        self.square = ops.Square()
+        self.sqrt = ops.Sqrt()
+        self.zeros = ops.Zeros()
+        self.ones = ops.Ones()
+        self.norm = nn.Norm()
+        self.add = ops.Add()
+        self.cast = ops.Cast()
+        self.cv_list = Tensor(np.arange(SMIN, SMAX, DS, dtype=np.float32)[0:self.grid_num], dtype=mstype.float32)
+        self.init_tensor()
+        self.op_define()
+        self.update = False
+        self.constant_random_force = Tensor(np.zeros([self.atom_numbers, 3], np.float32), mstype.float32)
+        self.max_vel = 20
+        self.hsigmoid = nn.HSigmoid()
+        self.one_hill = Tensor([1], mstype.int32)
+        self.sqrt2 = Tensor(np.sqrt(2), mstype.float32)
+        self.kb = units.boltzmann()
+        self.kbt = self.kb * self.t
+        self.inv_kbt = 1.0 / self.kbt
+        self.wt_factor = -1.0 / (self.gamma - 1.0) * self.inv_kbt
+
+        self.network = network
+        self.index_add = ops.IndexAdd(axis=-1)
+        self.bias = Bias
+        self.keep_sum = P.ReduceSum(keep_dims=True)
+        self.grad = C.GradOperation()
+        self.squeeze = P.Squeeze(0)
+
+    def init_tensor(self):
+        '''init tensor'''
+        self.hills = Parameter(Tensor(np.zeros(self.grid_num), mstype.float32), requires_grad=False)
+        self.crd = Parameter(
Tensor(np.float32(np.asarray(self.md_info.coordinate).reshape([self.atom_numbers, 3])), mstype.float32), + requires_grad=False) + self.crd_to_uint_crd_cof = Tensor(np.asarray(self.md_info.pbc.crd_to_uint_crd_cof, np.float32), mstype.float32) + self.uint_dr_to_dr_cof = Parameter( + Tensor(np.asarray(self.md_info.pbc.uint_dr_to_dr_cof, np.float32), mstype.float32), requires_grad=False) + self.box_length = Tensor(self.md_info.box_length, mstype.float32) + self.virtual_box_length = Tensor([0., 0., 0.], mstype.float32) + self.charge = Parameter(Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32), + requires_grad=False) + self.old_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32), + requires_grad=False) + self.last_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32), + requires_grad=False) + self.uint_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32), + requires_grad=False) + self.mass_inverse = Tensor(self.md_info.h_mass_inverse, mstype.float32) + self.res_start = Tensor(self.md_info.h_res_start, mstype.int32) + self.res_end = Tensor(self.md_info.h_res_end, mstype.int32) + self.mass = Tensor(self.md_info.h_mass, mstype.float32) + self.velocity = Parameter(Tensor(self.md_info.velocity, mstype.float32), requires_grad=False) + self.acc = Parameter(Tensor(np.zeros([self.atom_numbers, 3], np.float32), mstype.float32), requires_grad=False) + self.grid_n = Tensor(self.nb_info.grid_n, mstype.int32) + self.grid_length_inverse = Tensor(self.nb_info.grid_length_inverse, mstype.float32) + self.bucket = Parameter(Tensor( + np.asarray(self.nb_info.bucket, np.int32).reshape([self.grid_numbers, self.max_atom_in_grid_numbers]), + mstype.int32), requires_grad=False) + self.atom_numbers_in_grid_bucket = Parameter(Tensor(self.nb_info.atom_numbers_in_grid_bucket, mstype.int32), + requires_grad=False) + self.atom_in_grid_serial = Parameter(Tensor(np.zeros([self.nb_info.atom_numbers,], np.int32), mstype.int32), + requires_grad=False) + self.pointer = Parameter( + Tensor(np.asarray(self.nb_info.pointer, np.int32).reshape([self.grid_numbers, 125]), mstype.int32), + requires_grad=False) + self.nl_atom_numbers = Parameter(Tensor(np.zeros([self.atom_numbers,], np.int32), mstype.int32), + requires_grad=False) + self.nl_atom_serial = Parameter( + Tensor(np.zeros([self.atom_numbers, self.max_neighbor_numbers], np.int32), mstype.int32), + requires_grad=False) + + self.excluded_list_start = Tensor(np.asarray(self.nb_info.excluded_list_start, np.int32), mstype.int32) + self.excluded_list = Tensor(np.asarray(self.nb_info.excluded_list, np.int32), mstype.int32) + self.excluded_numbers = Tensor(np.asarray(self.nb_info.excluded_numbers, np.int32), mstype.int32) + self.need_refresh_flag = Tensor(np.asarray([0], np.int32), mstype.int32) + self.sqrt_mass = Tensor(self.liujian_info.h_sqrt_mass, mstype.float32) + self.rand_state = Parameter(Tensor(self.liujian_info.rand_state, mstype.float32)) + self.zero_fp_tensor = Tensor(np.asarray([0,], np.float32)) + + def op_define(self): + '''op define''' + self.mdtemp = P.MDTemperature(self.residue_numbers, self.atom_numbers) + self.setup_random_state = P.MDIterationSetupRandState(self.atom_numbers, self.random_seed) + + self.md_iteration_leap_frog_liujian = P.MDIterationLeapFrogLiujian(self.atom_numbers, self.half_dt, self.dt, + self.exp_gamma) + + self.neighbor_list_update_init = P.NeighborListUpdate(grid_numbers=self.grid_numbers, + 
atom_numbers=self.atom_numbers, not_first_time=0, + nxy=self.nxy, + excluded_atom_numbers=self.excluded_atom_numbers, + cutoff_square=self.cutoff_square, + half_skin_square=self.half_skin_square, + cutoff_with_skin=self.cutoff_with_skin, + half_cutoff_with_skin=self.half_cutoff_with_skin, + cutoff_with_skin_square=self.cutoff_with_skin_square, + refresh_interval=self.refresh_interval, + cutoff=self.cutoff, skin=self.skin, + max_atom_in_grid_numbers=self.max_atom_in_grid_numbers, + max_neighbor_numbers=self.max_neighbor_numbers) + + self.random_force = Tensor(np.zeros([self.atom_numbers, 3], np.float32), mstype.float32) + + def update_hills(self, index, value): + hills = ops.TensorScatterAdd()(self.hills, + F.expand_dims(index, -1), + F.expand_dims(self.cast(value, mstype.float32), -1)) + return hills + + def simulation_caculate_cybertron_force(self, positions, step, atom_types=None): + """simulation_caculate_cybertron_force""" + forces = -1 * self.grad(self.network)(positions, + atom_types, + None, + None) + cv = self.norm(self.add(self.last_crd[11], -self.last_crd[14])) + cv_index = self.cast((cv - self.smin) / self.ds, mstype.int32) + cv_index = cv_index * (cv_index >= 0) + cv_index = cv_index * (cv_index < self.grid_num) + (self.grid_num - 1) * (cv_index >= self.grid_num) + self.hills = self.update_hills(cv_index, step % self.meta_interval == 0) + bias_cell = self.bias(self.hills, + smin=self.smin, + smax=self.smax, + ds=self.ds, + omega=self.omega, + sigma=self.sigma, + dt=self.ddt, + t=self.t, + alpha=self.alpha, + gamma=self.gamma, + wall_potential=self.wall_potential, + kappa=self.kappa, + upper_bound=self.upper_bound_index, + lower_bound=self.lower_bound_index, + factor=self.wall_factor) + entropy_force = self.grad(bias_cell)(self.last_crd) + tforces = P.AddN()([self.squeeze(forces), -self.meta * entropy_force]) + return tforces + + def simulation_caculate_cybertron_energy(self, positions, atom_types=None): + energy = self.network(positions, atom_types, None, None) + energy = self.squeeze(energy) + return energy + + def simulation_temperature(self): + '''caculate temperature''' + res_ek_energy = self.mdtemp(self.res_start, self.res_end, self.velocity, self.mass) + temperature = P.ReduceSum()(res_ek_energy) + return temperature + + def simulation_mditeration_leapfrog_liujian(self, inverse_mass, sqrt_mass_inverse, crd, frc, rand_state, + random_frc): + '''simulation leap frog iteration liujian''' + crd = self.md_iteration_leap_frog_liujian(inverse_mass, sqrt_mass_inverse, self.velocity, crd, frc, self.acc, + rand_state, random_frc) + + vel = F.depend(self.velocity, crd) + vel = (self.hsigmoid(vel * 3 / self.max_vel) - 0.5) * 2 * self.max_vel + acc = F.depend(self.acc, crd) + return vel, crd, acc + + def main_print(self, *args): + """compute the temperature""" + steps, temperature, total_potential_energy, _, _, _, _, _, _, _ = list(args) + if steps == 0: + print("_steps_ _TEMP_ _TOT_POT_ENE_ _CVariable_ _Bias_Potential_") + + temperature = temperature.asnumpy() + total_potential_energy = total_potential_energy.asnumpy() + cv = self.norm(self.add(self.last_crd[11], -self.last_crd[14])) + biasp = self.sum( + self.dt * self.hills * self.omega * self.exp(-self.square(cv - self.cv_list) / 2 / self.square(self.sigma))) + print("{:>4.0f} {:>8.3f} {:>10.3f} {:>12.3f} {:>13.3f}" + "".format(steps, float(temperature), float(total_potential_energy), cv.asnumpy(), biasp.asnumpy())) + + if self.file is not None: + self.file.write("{:>4.0f} {:>8.3f} {:>10.3f} {:>12.3f} {:>13.3f}\n" + 
"".format(steps, float(temperature), float(total_potential_energy), cv.asnumpy(), + biasp.asnumpy())) + if self.datfile is not None: + self.datfile.write(self.crd.asnumpy()) + + def main_initial(self): + """main initial""" + if self.control.mdout: + self.file = open(self.control.mdout, 'w') + self.file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _CVariable_ _Bias_Potential_\n") + if self.control.mdcrd: + self.datfile = open(self.control.mdcrd, 'wb') + + def main_destroy(self): + """main destroy""" + if self.file is not None: + self.file.close() + print("Save .out file successfully!") + if self.datfile is not None: + self.datfile.close() + print("Save .dat file successfully!") + + def construct(self, step, print_step): + '''construct''' + self.last_crd = self.crd + if step == 0: + res = self.neighbor_list_update_init(self.atom_numbers_in_grid_bucket, self.bucket, self.crd, + self.virtual_box_length, self.grid_n, self.grid_length_inverse, + self.atom_in_grid_serial, self.old_crd, self.crd_to_uint_crd_cof, + self.uint_crd, self.pointer, self.nl_atom_numbers, self.nl_atom_serial, + self.uint_dr_to_dr_cof, self.excluded_list_start, self.excluded_list, + self.excluded_numbers, self.need_refresh_flag, self.refresh_count) + self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res) + self.nl_atom_serial = F.depend(self.nl_atom_serial, res) + self.uint_dr_to_dr_cof = F.depend(self.uint_dr_to_dr_cof, res) + self.old_crd = F.depend(self.old_crd, res) + self.atom_numbers_in_grid_bucket = F.depend(self.atom_numbers_in_grid_bucket, res) + self.bucket = F.depend(self.bucket, res) + self.atom_in_grid_serial = F.depend(self.atom_in_grid_serial, res) + self.pointer = F.depend(self.pointer, res) + + positions = F.expand_dims(self.crd, 0) + force = self.simulation_caculate_cybertron_force(positions, step) + bond_energy_sum = self.zero_fp_tensor + angle_energy_sum = self.zero_fp_tensor + dihedral_energy_sum = self.zero_fp_tensor + nb14_lj_energy_sum = self.zero_fp_tensor + nb14_cf_energy_sum = self.zero_fp_tensor + lj_energy_sum = self.zero_fp_tensor + ee_ene = self.zero_fp_tensor + total_energy = self.simulation_caculate_cybertron_energy(positions) + + temperature = self.simulation_temperature() + self.rand_state = self.setup_random_state() + self.velocity, self.crd, _ = self.simulation_mditeration_leapfrog_liujian(self.mass_inverse, + self.sqrt_mass, self.crd, force, + self.rand_state, + self.random_force) + + res = self.ds + self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res) + self.nl_atom_serial = F.depend(self.nl_atom_serial, res) + else: + + positions = F.expand_dims(self.crd, 0) + force = self.simulation_caculate_cybertron_force(positions, step) + if print_step == 0: + bond_energy_sum = self.zero_fp_tensor + angle_energy_sum = self.zero_fp_tensor + dihedral_energy_sum = self.zero_fp_tensor + nb14_lj_energy_sum = self.zero_fp_tensor + nb14_cf_energy_sum = self.zero_fp_tensor + lj_energy_sum = self.zero_fp_tensor + ee_ene = self.zero_fp_tensor + total_energy = self.simulation_caculate_cybertron_energy(positions) + else: + bond_energy_sum = self.zero_fp_tensor + angle_energy_sum = self.zero_fp_tensor + dihedral_energy_sum = self.zero_fp_tensor + nb14_lj_energy_sum = self.zero_fp_tensor + nb14_cf_energy_sum = self.zero_fp_tensor + lj_energy_sum = self.zero_fp_tensor + ee_ene = self.zero_fp_tensor + total_energy = self.zero_fp_tensor + temperature = self.simulation_temperature() + self.velocity, self.crd, _ = self.simulation_mditeration_leapfrog_liujian(self.mass_inverse, + self.sqrt_mass, self.crd, force, + 
self.rand_state, + self.random_force) + + res = self.ds + self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res) + self.nl_atom_serial = F.depend(self.nl_atom_serial, res) + return temperature, total_energy, bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, \ + nb14_cf_energy_sum, lj_energy_sum, ee_ene, res diff --git a/MindSPONGE/examples/covid/scripts/run_all.sh b/MindSPONGE/examples/covid/scripts/run_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..c4d90f383018e2c30a94df89fefd0c0990968ea2 --- /dev/null +++ b/MindSPONGE/examples/covid/scripts/run_all.sh @@ -0,0 +1,111 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +mkdir -p min1 +cd min1 +cat > min1.in < min2.in < min3.in < min4.in < heat.in < pres.in << EOF +S3 press + mode = npt + step_limit = 60000 + dt = 2e-3 + constrain_mode = simple_constrain + write_information_interval = 2500 + thermostat = langevin_liu + barostat = berendsen +EOF +python ../../src/run_npt.py --i ./pres.in --amber_parm ../../data/ace2.parm7 --c ../heat/heat.rst7 --r pres.rst7 +cd .. + +mkdir -p product +cd product +cat > md.in << EOF +S4 product + mode = npt + step_limit = 750000 + dt = 4e-3 + constrain_mode = simple_constrain + write_information_interval = 2500 + write_restart_file_interval = 250000 + thermostat = langevin_liu + langevin_liu_velocity_max = 20 + langevin_liu_gamma = 10.0 + barostat = berendsen +EOF +python ../../src/run_npt.py --i ./md.in --amber_parm ../../data/ace2.parm7 --c ../pres/pres.rst7 --r product.rst7 +cd .. + diff --git a/MindSPONGE/examples/covid/scripts/run_npt.sh b/MindSPONGE/mindsponge/md/cybertron/__init__.py similarity index 41% rename from MindSPONGE/examples/covid/scripts/run_npt.sh rename to MindSPONGE/mindsponge/md/cybertron/__init__.py index bdbdc71e5a04374bff7a3e404ae38bd914925a2c..e5107d8089b2bde988e416c107de1afbbcf30993 100644 --- a/MindSPONGE/examples/covid/scripts/run_npt.sh +++ b/MindSPONGE/mindsponge/md/cybertron/__init__.py @@ -1,5 +1,13 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd +# ============================================================================ +# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University +# +# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang, +# Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao +# +# This code is a part of Cybertron-Code package. +# +# The Cybertron-Code is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,36 +21,3 @@ # See the License for the specific language governing permissions and # limitations under the License. 
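The collective-variable and hill bookkeeping in SimulationCybertron is terse, so here is a NumPy restatement of the same logic, as a sketch under the module-level constants SMIN, DS, OMEGA and SIGMA defined above. The CV is the distance between atom 11 (a carbon) and atom 14 (the oxygen), a counter grid records which CV bin is visited at each deposition step (every meta_interval steps, as in simulation_caculate_cybertron_force), and the bias energy reported by main_print is a Gaussian-smoothed sum over that grid.

import numpy as np

SMIN, DS, GRID_NUM = 0.0, 0.04, 200
OMEGA, SIGMA = 50.0, 0.005
cv_grid = SMIN + DS * np.arange(GRID_NUM, dtype=np.float32)
hills = np.zeros(GRID_NUM, dtype=np.float32)

def deposit_hill(crd, step, meta_interval=5):
    # Count one hill in the bin of the current CV every meta_interval steps.
    s = np.linalg.norm(crd[11] - crd[14])
    idx = int(np.clip((s - SMIN) / DS, 0, GRID_NUM - 1))
    hills[idx] += float(step % meta_interval == 0)

def bias_energy(s, dt):
    # Gaussian-smoothed bias at CV value s; dt is the MD time step, matching
    # the self.dt prefactor used in main_print.
    return float(np.sum(dt * hills * OMEGA *
                        np.exp(-np.square(s - cv_grid) / (2.0 * SIGMA ** 2))))
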
# ============================================================================
-
-mkdir -p pres
-cd pres
-cat > pres.in << EOF
-S3 press
-    mode = npt
-    step_limit = 60000
-    dt = 2e-3
-    constrain_mode = simple_constrain
-    write_information_interval = 2500
-    thermostat = langevin_liu
-    barostat = berendsen
-EOF
-python ../../src/run_npt.py --i ./pres.in --amber_parm ../../data/ace2.parm7 --c ../heat/heat.rst7 --r pres.rst7
-cd ..
-
-#mkdir -p product
-#cd product
-#cat > md.in << EOF
-#S4 product
-#    mode = npt
-#    step_limit = 750000
-#    dt = 4e-3
-#    constrain_mode = simple_constrain
-#    write_information_interval = 2500
-#    write_restart_file_interval = 250000
-#    thermostat = langevin_liu
-#    langevin_liu_velocity_max = 20
-#    langevin_liu_gamma = 10.0
-#    barostat = berendsen
-#EOF
-#python ../../src/run_npt.py --i ./md.in --amber_parm ../../data/ace2.parm7 --c ../pres/pres.rst7 --r product.rst7
-#cd ..
diff --git a/MindSPONGE/mindsponge/md/cybertron/activations.py b/MindSPONGE/mindsponge/md/cybertron/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd91020ba55fff3ab95ac6094ae4957fec066157
--- /dev/null
+++ b/MindSPONGE/mindsponge/md/cybertron/activations.py
@@ -0,0 +1,129 @@
+# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University
+#
+# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang,
+#         Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao
+#
+# This code is a part of Cybertron-Code package.
+#
+# The Cybertron-Code is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""activations"""
+
+from mindspore import nn
+from mindspore.nn.layer.activation import _activation
+from mindspore.ops import operations as P
+
+__all__ = [
+    "ShiftedSoftplus",
+    "ScaledShiftedSoftplus",
+    "Swish",
+    "get_activation",
+    ]
+
+
+class ShiftedSoftplus(nn.Cell):
+    r"""Compute shifted soft-plus activation function.
+
+    .. math::
+       y = \ln\left(1 + e^{x}\right) - \ln(2)
+
+    Args:
+        x (mindspore.Tensor): input tensor.
+
+    Returns:
+        mindspore.Tensor: shifted soft-plus of input.
+
+    """
+    def __init__(self):
+        super().__init__()
+        self.log1p = P.Log1p()
+        self.exp = P.Exp()
+        self.ln2 = 0.6931471805599453
+
+    def __str__(self):
+        return "shifted_softplus"
+
+    def construct(self, x):
+        return self.log1p(self.exp(x)) - self.ln2
+
+
+class ScaledShiftedSoftplus(nn.Cell):
+    r"""Compute scaled shifted soft-plus activation function.
+
+    .. math::
+       y = 2\left[\ln\left(1 + e^{x}\right) - \ln(2)\right]
+
+    Args:
+        x (mindspore.Tensor): input tensor.
+
+    Returns:
+        mindspore.Tensor: scaled shifted soft-plus of input.
+
+    """
+    def __init__(self):
+        super().__init__()
+        self.softplus = P.Softplus()
+        self.ln2 = 0.6931471805599453
+
+    def __str__(self):
+        return "scaled_shifted_softplus"
+
+    def construct(self, x):
+        return 2 * (self.softplus(x) - self.ln2)
+
+
+class Swish(nn.Cell):
+    r"""Compute Swish (SiLU) activation function.
+
+    .. math::
+       y_i = x_i \cdot \sigma(x_i) = \frac{x_i}{1 + e^{-x_i}}
+
+    Args:
+        x (mindspore.Tensor): input tensor.
+
+    Returns:
+        mindspore.Tensor: swish of input.
+
+    """
+    def __init__(self):
+        super().__init__()
+        self.sigmoid = nn.Sigmoid()
+
+    def __str__(self):
+        return "swish"
+
+    def construct(self, x):
+        return x * self.sigmoid(x)
+
+
+_EXTENDED_ACTIVATIONS = {
+    'shifted': ShiftedSoftplus,
+    'scaledshifted': ScaledShiftedSoftplus,
+    'swish': Swish,
+}
+
+
+def get_activation(name):
+    """get activation"""
+    if name is None or isinstance(name, nn.Cell):
+        return name
+    if isinstance(name, str):
+        if name.lower() in _activation.keys():
+            return name
+        if name.lower() not in _EXTENDED_ACTIVATIONS.keys():
+            raise ValueError("The class corresponding to '{}' was not found.".format(name))
+        return _EXTENDED_ACTIVATIONS[name.lower()]()
+    raise TypeError("Unsupported activation type '{}'.".format(type(name)))
diff --git a/MindSPONGE/mindsponge/md/cybertron/aggregators.py b/MindSPONGE/mindsponge/md/cybertron/aggregators.py
new file mode 100644
index 0000000000000000000000000000000000000000..a68cbc7900391bec3d5f089eeb79939860eb2e99
--- /dev/null
+++ b/MindSPONGE/mindsponge/md/cybertron/aggregators.py
@@ -0,0 +1,439 @@
+# ============================================================================
+# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University
+#
+# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang,
+#         Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao
+#
+# This code is a part of Cybertron-Code package.
+#
+# The Cybertron-Code is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
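A quick numerical sanity check of the activation definitions above; a sketch assuming MindSpore runs in PyNative mode so the cells can be called eagerly. ShiftedSoftplus should match ln(1 + e^x) - ln 2 (so f(0) = 0), and 'swish' can equally be fetched through the get_activation registry.

import numpy as np
from mindspore import Tensor
from cybertron.activations import ShiftedSoftplus, get_activation

x_np = np.linspace(-2.0, 2.0, 5).astype(np.float32)
x = Tensor(x_np)

ssp = ShiftedSoftplus()  # ln(1 + e^x) - ln(2)
assert np.allclose(ssp(x).asnumpy(), np.log1p(np.exp(x_np)) - np.log(2.0), atol=1e-6)

swish = get_activation('swish')  # alias lookup in _EXTENDED_ACTIVATIONS
assert np.allclose(swish(x).asnumpy(), x_np / (1.0 + np.exp(-x_np)), atol=1e-6)
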
+# ============================================================================ +"""aggregators""" + +import mindspore as ms +from mindspore import nn +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from mindspore.common.initializer import Normal + +from cybertron.blocks import MLP, Dense +from cybertron.base import SoftmaxWithMask +from cybertron.base import MultiheadAttention + +__all__ = [ + "Aggregator", + "get_aggregator", + "TensorSummation", + "TensorMean", + "SoftmaxGeneralizedAggregator", + "PowermeanGeneralizedAggregator", + "ListAggregator", + "get_list_aggregator", + "ListSummation", + "ListMean", + "LinearTransformation", + "MultipleChannelRepresentation", +] + +_AGGREGATOR_ALIAS = dict() +_LIST_AGGREGATOR_ALIAS = dict() + + +def _aggregator_register(*aliases): + """Return the alias register.""" + def alias_reg(cls): + name = cls.__name__ + name = name.lower() + if name not in _AGGREGATOR_ALIAS: + _AGGREGATOR_ALIAS[name] = cls + + for alias in aliases: + if alias not in _AGGREGATOR_ALIAS: + _AGGREGATOR_ALIAS[alias] = cls + + return cls + + return alias_reg + + +def _list_aggregator_register(*aliases): + """Return the alias register.""" + def alias_reg(cls): + name = cls.__name__ + name = name.lower() + if name not in _LIST_AGGREGATOR_ALIAS: + _LIST_AGGREGATOR_ALIAS[name] = cls + + for alias in aliases: + if alias not in _LIST_AGGREGATOR_ALIAS: + _LIST_AGGREGATOR_ALIAS[alias] = cls + + return cls + + return alias_reg + + +class Aggregator(nn.Cell): + def __init__(self, dim=None, axis=-2,): + super().__init__() + + self.name = 'aggregator' + + self.dim = dim + self.axis = axis + self.reduce_sum = P.ReduceSum() + + +class ListAggregator(nn.Cell): + """list aggretor""" + def __init__(self, dim=None, num_agg=None, n_hidden=0, activation=None,): + super().__init__() + + self.dim = dim + self.num_agg = num_agg + self.n_hidden = n_hidden + self.activation = activation + + self.stack = P.Stack(-1) + self.reduce_sum = P.ReduceSum() + + +@_aggregator_register('sum') +class TensorSummation(Aggregator): + """tensor summation""" + def __init__(self, dim=None, axis=-2,): + super().__init__(dim=None, axis=axis,) + + self.name = 'sum' + + def __str__(self): + return "sum" + + def construct(self, nodes, node_mask=None): + if node_mask is not None: + nodes = nodes * node_mask + agg = self.reduce_sum(nodes, self.axis) + return agg + + +@_aggregator_register('mean') +class TensorMean(Aggregator): + """tensor mean""" + def __init__(self, dim=None, axis=-2,): + super().__init__(dim=None, axis=axis,) + self.name = 'mean' + + self.reduce_mean = P.ReduceMean() + self.mol_sum = P.ReduceSum(keep_dims=True) + + def __str__(self): + return "mean" + + def construct(self, nodes, node_mask=None, nodes_number=None): + if node_mask is not None: + nodes = nodes * node_mask + agg = self.reduce_sum(nodes, self.axis) + return agg / nodes_number + return self.reduce_mean(nodes, self.axis) + + +# Softmax-based generalized mean-max-sum aggregator +@_aggregator_register('softmax') +class SoftmaxGeneralizedAggregator(Aggregator): + """softmax generalized aggregator""" + def __init__(self, dim, axis=-2,): + super().__init__(dim=dim, axis=axis,) + + self.name = 'softmax' + + self.beta = ms.Parameter(initializer('one', 1), name="beta") + self.rho = ms.Parameter(initializer('zero', 1), name="rho") + + self.softmax = P.Softmax(axis=self.axis) + self.softmax_with_mask = SoftmaxWithMask(axis=self.axis) + self.mol_sum = P.ReduceSum(keep_dims=True) + + self.expand_ones = 
P.Ones()((1, 1, self.dim), ms.int32) + + def __str__(self): + return "softmax" + + def construct(self, nodes, node_mask=None, nodes_number=None): + """construct""" + if nodes_number is None: + nodes_number = nodes.shape[self.axis] + + scale = nodes_number / (1 + self.beta * (nodes_number - 1)) + px = nodes * self.rho + + if node_mask is None: + agg_nodes = self.softmax(px) * nodes + else: + mask = (self.expand_ones * node_mask) > 0 + agg_nodes = self.softmax_with_mask(px, mask) * nodes * node_mask + + agg_nodes = self.reduce_sum(agg_nodes, self.axis) + + return scale * agg_nodes + + +# PowerMean-based generalized mean-max-sum aggregator +@_aggregator_register('powermean') +class PowermeanGeneralizedAggregator(Aggregator): + """power mean generalized aggregator""" + def __init__(self, dim, axis=-2,): + super().__init__(dim=dim, axis=axis,) + self.name = 'powermean' + self.beta = ms.Parameter(initializer('one', 1), name="beta") + self.rho = ms.Parameter(initializer('one', 1), name="rho") + + self.power = P.Pow() + self.mol_sum = P.ReduceSum(keep_dims=True) + + def __str__(self): + return "powermean" + + def construct(self, nodes, node_mask=None, nodes_number=None): + """construct""" + if nodes_number is None: + nodes_number = nodes.shape[self.axis] + + scale = nodes_number / (1 + self.beta * (nodes_number - 1)) + xp = self.power(nodes, self.rho) + if node_mask is not None: + xp = xp * node_mask + agg_nodes = self.reduce_sum(xp, self.axis) + + return self.power(scale * agg_nodes, 1.0 / self.rho) + + +@_aggregator_register('transformer') +class TransformerAggregator(Aggregator): + """trasnformer aggregator""" + def __init__(self, dim, axis=-2, n_heads=8,): + super().__init__( + dim=dim, + axis=axis, + ) + + self.name = 'transformer' + + self.a2q = Dense(dim, dim, has_bias=False) + self.a2k = Dense(dim, dim, has_bias=False) + self.a2v = Dense(dim, dim, has_bias=False) + + self.layer_norm = nn.LayerNorm((dim,), -1, -1) + + self.multi_head_attention = MultiheadAttention( + dim, n_heads, dim_tensor=3) + + self.squeeze = P.Squeeze(-1) + self.mean = TensorMean(dim, axis) + + def __str__(self): + return "transformer" + + def construct(self, nodes, node_mask=None, nodes_number=None): + r"""Transformer type aggregator. + + Args: + nodes (Mindspore.Tensor[float] [B, A, F]): + + Returns: + Mindspore.Tensor [..., F]: multi-head attention output. 
+
+        """
+        # [B, A, F]
+        x = self.layer_norm(nodes)
+
+        # [B, A, F]
+        q = self.a2q(x)
+        k = self.a2k(x)
+        v = self.a2v(x)
+
+        if node_mask is not None:
+            mask = self.squeeze(node_mask)
+        else:
+            mask = node_mask
+
+        # [B, A, F]
+        x = self.multi_head_attention(q, k, v, mask)
+
+        # [B, 1, F]
+        return self.mean(x, node_mask, nodes_number)
+
+
+@_list_aggregator_register('sum')
+class ListSummation(ListAggregator):
+    """list summation"""
+    def __init__(self,
+                 dim=None,
+                 num_agg=None,
+                 n_hidden=0,
+                 activation=None,
+                 ):
+        super().__init__(
+            dim=dim,
+            num_agg=num_agg,
+            n_hidden=n_hidden,
+            activation=activation,
+        )
+
+    def __str__(self):
+        return "sum"
+
+    def construct(self, xlist, node_mask=None):
+        xt = self.stack(xlist)
+        y = self.reduce_sum(xt, -1)
+        if node_mask is not None:
+            y = y * node_mask
+        return y
+
+
+@_list_aggregator_register('mean')
+class ListMean(ListAggregator):
+    """list mean"""
+    def __init__(self, dim=None, num_agg=None, n_hidden=0, activation=None,):
+        super().__init__(
+            dim=dim,
+            num_agg=num_agg,
+            n_hidden=n_hidden,
+            activation=activation,
+        )
+
+        self.reduce_mean = P.ReduceMean()
+
+    def __str__(self):
+        return "mean"
+
+    def construct(self, xlist, node_mask=None):
+        xt = self.stack(xlist)
+        y = self.reduce_mean(xt, -1)
+        if node_mask is not None:
+            y = y * node_mask
+        return y
+
+
+@_list_aggregator_register('linear')
+class LinearTransformation(ListAggregator):
+    """linear transformation"""
+    def __init__(self,
+                 dim,
+                 num_agg=None,
+                 n_hidden=0,
+                 activation=None,
+                 ):
+        super().__init__(
+            dim=dim,
+            num_agg=num_agg,
+            n_hidden=n_hidden,
+            activation=activation,
+        )
+        self.scale = ms.Parameter(
+            initializer(
+                Normal(1.0), [self.dim,]), name="scale")
+        self.shift = ms.Parameter(
+            initializer(
+                Normal(1.0), [self.dim,]), name="shift")
+
+    def __str__(self):
+        return "linear"
+
+    def construct(self, ylist, node_mask=None):
+        yt = self.stack(ylist)
+        ysum = self.reduce_sum(yt, -1)
+        y = self.scale * ysum + self.shift
+        if node_mask is not None:
+            y = y * node_mask
+        return y
+
+# Multiple-Channel Representation Readout
+
+
+@_list_aggregator_register('mcr')
+class MultipleChannelRepresentation(ListAggregator):
+    """multiple channel representation"""
+    def __init__(self,
+                 dim,
+                 num_agg,
+                 n_hidden=0,
+                 activation=None,
+                 ):
+        super().__init__(
+            dim=dim,
+            num_agg=num_agg,
+            n_hidden=n_hidden,
+            activation=activation,
+        )
+
+        sub_dim = self.dim // self.num_agg
+        last_dim = self.dim - (sub_dim * (self.num_agg - 1))
+        sub_dims = [sub_dim for _ in range(self.num_agg - 1)]
+        sub_dims.append(last_dim)
+
+        if self.n_hidden > 0:
+            hidden_layers = [dim,] * self.n_hidden
+            self.mcr = nn.CellList([
+                MLP(self.dim, sub_dims[i], hidden_layers, activation=self.activation)
+                for i in range(self.num_agg)
+            ])
+        else:
+            self.mcr = nn.CellList([
+                Dense(self.dim, sub_dims[i], activation=self.activation)
+                for i in range(self.num_agg)
+            ])
+
+        self.concat = P.Concat(-1)
+
+    def __str__(self):
+        return "MCR"
+
+    def construct(self, xlist, node_mask=None):
+        xt = ()
+        for i in range(self.num_agg):
+            xt = xt + (self.mcr[i](xlist[i]),)
+        y = self.concat(xt)
+        if node_mask is not None:
+            y = y * node_mask
+        return y
+
+
+def get_aggregator(obj, dim, axis=-2):
+    if obj is None or isinstance(obj, Aggregator):
+        return obj
+    if isinstance(obj, str):
+        if obj.lower() not in _AGGREGATOR_ALIAS.keys():
+            raise ValueError(
+                "The class corresponding to '{}' was not found.".format(obj))
+        return _AGGREGATOR_ALIAS[obj.lower()](dim=dim, axis=axis)
+    raise TypeError("Unsupported Aggregator type '{}'.".format(type(obj)))
+
+
+def get_list_aggregator(obj, dim, num_agg, n_hidden=0, activation=None,):
+    """get list aggregator"""
+    if obj is None or isinstance(obj, ListAggregator):
+        return obj
+    if isinstance(obj, str):
+        if obj.lower() not in _LIST_AGGREGATOR_ALIAS.keys():
+            raise ValueError(
+                "The class corresponding to '{}' was not found.".format(obj))
+        return _LIST_AGGREGATOR_ALIAS[obj.lower()](
+            dim=dim,
+            num_agg=num_agg,
+            n_hidden=n_hidden,
+            activation=activation,
+        )
+    raise TypeError("Unsupported ListAggregator type '{}'.".format(type(obj)))
diff --git a/MindSPONGE/mindsponge/md/cybertron/base.py b/MindSPONGE/mindsponge/md/cybertron/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e9cfa1869ed0a9fe9070acf12f341b3381de91c
--- /dev/null
+++ b/MindSPONGE/mindsponge/md/cybertron/base.py
@@ -0,0 +1,755 @@
+# ============================================================================
+# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University
+#
+# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang,
+#         Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao
+#
+# This code is a part of Cybertron-Code package.
+#
+# The Cybertron-Code is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
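A usage sketch for the aggregator registry above, assuming MindSpore is available: a [B, A, F] batch of per-atom features is pooled over the atom axis, node_mask zeroes padded atoms, and nodes_number supplies the true atom count so that 'mean' divides correctly.

import numpy as np
from mindspore import Tensor
from cybertron.aggregators import get_aggregator

b, a, f = 2, 5, 8
nodes = Tensor(np.random.rand(b, a, f).astype(np.float32))   # [B, A, F]
node_mask = Tensor(np.ones((b, a, 1), np.float32))           # 1 = real atom
nodes_number = Tensor(np.full((b, 1), a, np.float32))        # atoms per molecule

summed = get_aggregator('sum', dim=f)(nodes, node_mask)               # [B, F]
mean = get_aggregator('mean', dim=f)(nodes, node_mask, nodes_number)  # [B, F]
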
+# ============================================================================ +"""cybertron.base""" + +import numpy as np +import mindspore as ms +from mindspore import nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer, Constant + +from cybertron.units import units +from cybertron.blocks import MLP, Dense, Residual +from cybertron.cutoff import SmoothCutoff + +__all__ = [ + "GraphNorm", + "Filter", + "ResFilter", + "CFconv", + "Aggregate", + "SmoothReciprocal", + "SoftmaxWithMask", + "PositionalEmbedding", + "MultiheadAttention", + "FeedForward", + "Pondering", + "ACTWeight", + "Num2Mask", + "Number2FullConnectNeighbors", + "Types2FullConnectNeighbors", +] + + +class GraphNorm(nn.Cell): + """graph norm""" + def __init__(self, + dim_feature, + node_axis=-2, + alpha_init='one', + beta_init='zero', + gamma_init='one' + ): + super().__init__() + self.alpha = Parameter( + initializer( + alpha_init, + dim_feature), + name="alpha") + self.beta = Parameter(initializer(beta_init, dim_feature), name="beta") + self.gamma = Parameter( + initializer( + gamma_init, + dim_feature), + name="gamma") + + self.axis = node_axis + + self.reduce_mean = P.ReduceMean(keep_dims=True) + + self.sqrt = P.Sqrt() + + def construct(self, nodes): + """construct""" + mu = self.reduce_mean(nodes, self.axis) + + nodes2 = nodes * nodes + mu2 = self.reduce_mean(nodes2, self.axis) + + a = self.alpha + sigma2 = mu2 + (a * a - 2 * a) * mu * mu + sigma = self.sqrt(sigma2) + + y = self.gamma * (nodes - a * mu) / sigma + self.beta + + return y + + +class Filter(nn.Cell): + """filter""" + def __init__(self, + num_rbf, + dim_filter, + activation, + n_hidden=1, + ): + super().__init__() + + if n_hidden > 0: + hidden_layers = [dim_filter for _ in range(n_hidden)] + self.dense_layers = MLP( + num_rbf, + dim_filter, + hidden_layers, + activation=activation) + else: + self.dense_layers = Dense( + num_rbf, dim_filter, activation=activation) + + def construct(self, rbf): + """construct""" + return self.dense_layers(rbf) + + +class ResFilter(nn.Cell): + """resgilter""" + def __init__(self, + num_rbf, + dim_filter, + activation, + n_hidden=1, + ): + super().__init__() + + self.linear = Dense(num_rbf, dim_filter, activation=None) + self.residual = Residual( + dim_filter, + activation=activation, + n_hidden=n_hidden) + + def construct(self, x): + """construct""" + lx = self.linear(x) + return self.residual(lx) + + +class CFconv(nn.Cell): + """CFcony""" + def __init__(self, num_rbf, dim_filter, activation,): + super().__init__() + # filter block used in interaction block + self.filter = Filter(num_rbf, dim_filter, activation) + + def construct(self, x, f_ij, c_ij=None): + w = self.filter(f_ij) + if c_ij is not None: + w = w * F.expand_dims(c_ij, -1) + + return x * w + + +class Aggregate(nn.Cell): + """Pooling layer based on sum or average with optional masking. + + Args: + axis (int): axis along which pooling is done. + mean (bool, optional): if True, use average instead for sum pooling. + + """ + + def __init__(self, axis, mean=False): + super().__init__() + self.average = mean + self.axis = axis + self.reduce_sum = P.ReduceSum() + self.maximum = P.Maximum() + + def construct(self, inputs, mask=None): + r"""Compute layer output. + + Args: + input (torch.Tensor): input data. + mask (torch.Tensor, optional): mask to be applied; e.g. neighbors mask. + + Returns: + torch.Tensor: layer output. 
+ + """ + # mask input + if mask is not None: + inputs = inputs * F.expand_dims(mask, -1) + # compute sum of input along axis + + y = self.reduce_sum(inputs, self.axis) + # compute average of input along axis + if self.average: + # get the number of items along axis + if mask is not None: + n = self.reduce_sum(mask, self.axis) + n = self.maximum(n, other=F.ones_like(n)) + else: + n = inputs.shape[self.axis] + + y = y / n + return y + + +class SmoothReciprocal(nn.Cell): + """SmoothReciprocal""" + def __init__(self, + dmax=units.length(1, 'nm'), + cutoff_network=None + ): + super().__init__() + + if cutoff_network is None: + self.cutoff_network = SmoothCutoff(dmax, return_mask=False) + else: + self.cutoff_network = cutoff_network(dmax, return_mask=False) + + self.sqrt = P.Sqrt() + + def construct(self, rij, mask): + """construct""" + phi2rij = self.cutoff_network(rij * 2, mask) + + r_near = phi2rij * (1.0 / self.sqrt(rij * rij + 1.0)) + r_far = F.select(rij > 0, (1.0 - phi2rij) * + (1.0 / rij), F.zeros_like(rij)) + + reciprocal = r_near + r_far + if mask is not None: + reciprocal = reciprocal * mask + + return reciprocal + + +class SoftmaxWithMask(nn.Cell): + """SoftmaxWithMask""" + def __init__(self, axis=-1): + super().__init__() + self.softmax = P.Softmax(axis) + + self.large_neg = -5e4 + + def construct(self, x, mask): + large_neg = F.ones_like(x) * self.large_neg + xm = F.select(mask, x, large_neg) + + return self.softmax(xm) + + +class PositionalEmbedding(nn.Cell): + """PositionalEmbedding""" + def __init__( + self, + dim, + use_distances=True, + use_bonds=False, + use_public_layer_norm=True): + super().__init__() + + if not (use_bonds or use_distances): + raise ValueError( + '"use_bonds" and "use_distances" cannot be both "False" when initializing "PositionalEmbedding"!') + + self.use_distances = use_distances + self.use_bonds = use_bonds + + if use_public_layer_norm: + self.norm = nn.LayerNorm((dim,), -1, -1) + self.norm_q = self.norm + self.norm_k = self.norm + self.norm_v = self.norm + else: + self.norm_q = nn.LayerNorm((dim,), -1, -1) + self.norm_k = nn.LayerNorm((dim,), -1, -1) + self.norm_v = nn.LayerNorm((dim,), -1, -1) + + self.x2q = Dense(dim, dim, has_bias=False) + self.x2k = Dense(dim, dim, has_bias=False) + self.x2v = Dense(dim, dim, has_bias=False) + + self.mul = P.Mul() + self.concat = P.Concat(-2) + + def construct( + self, + xi, + xij, + g_ii=1, + g_ij=1, + b_ii=0, + b_ij=0, + c_ij=None, + t=0): + r"""Get query, key and query from atom types and positions + + Args: + xi (Mindspore.Tensor [B, A, F]): + g_ii (Mindspore.Tensor [B, A, F]): + xij (Mindspore.Tensor [B, A, N, F]): + g_ij (Mindspore.Tensor [B, A, N, F]): + t (Mindspore.Tensor [F]): + + Marks: + B: Batch size + A: Number of atoms + N: Number of neighbor atoms + N': Number of neighbor atoms and itself (N' = N + 1) + F: Dimensions of feature space + + Returns: + query (Mindspore.Tensor [B, A, 1, F]): + key (Mindspore.Tensor [B, A, N', F]): + value (Mindspore.Tensor [B, A, N', F]): + + """ + + if self.use_distances: + # [B, A, F] * [B, A, F] + [B, A, F] = [B, A, F] + a_ii = xi * g_ii + # [B, A, N, F] * [B, A, N, F] + [B, A, N, F] = [B, A, N, F] + a_ij = xij * g_ij + else: + a_ii = xi + a_ij = xij + + if self.use_bonds: + e_ii = a_ii + b_ii + e_ij = a_ij + b_ij + else: + e_ii = a_ii + e_ij = a_ij + + # [B, A, 1, F] + e_ii = F.expand_dims(e_ii, -2) + # [B, A, N', F] + [B, A, N', F] + e_ij = self.concat((e_ii, e_ij)) + + xq = self.norm_q(e_ii + t) + xk = self.norm_k(e_ij + t) + xv = self.norm_v(e_ij + t) + # [B, A, 
1, F] + query = self.x2q(xq) + # [B, A, N', F] + key = self.x2k(xk) + # [B, A, N', F] + value = self.x2v(xv) + + if c_ij is not None: + # [B, A, N', F] * [B, A, N', 1] + key = key * F.expand_dims(c_ij, -1) + value = value * F.expand_dims(c_ij, -1) + + return query, key, value + + +class MultiheadAttention(nn.Cell): + r"""Compute multi-head attention. + + Args: + dim_feature (int): Diension of feature space (F) + n_heads (int): Number of heads (h) + dim_tensor (int): Dimension of input tensor (D) + + Signs: + X: Dimension to be aggregated + F: Dimension of Feature space + h: Number of heads for multi-head attention + f: Dimensions per head (F = f * h) + + """ + + def __init__(self, dim_feature, n_heads=8, dim_tensor=4): + super().__init__() + + # D + if dim_tensor < 2: + raise ValueError('dim_tensor must be larger than 1') + + # h + self.n_heads = n_heads + + # f = F / h + self.size_per_head = dim_feature // n_heads + # 1.0 / sqrt(f) + scores_mul = 1.0 / np.sqrt(float(self.size_per_head)) + self.scores_mul = ms.Tensor(scores_mul, ms.float32) + + # shape = (h, f) + self.reshape_tail = (self.n_heads, self.size_per_head) + + self.output = Dense(dim_feature, dim_feature, has_bias=False) + + self.mul = P.Mul() + self.div = P.Div() + self.softmax = P.Softmax() + self.bmm = P.BatchMatMul() + self.bmmt = P.BatchMatMul(transpose_b=True) + self.reducesum = P.ReduceSum(keep_dims=True) + + # [0,1,...,D-1] + ranges = list(range(dim_tensor + 1)) + tmpid = ranges[-2] + ranges[-2] = ranges[-3] + ranges[-3] = tmpid + # [0,1,...,D-2,D-3,D-1] + self.trans_shape = tuple(ranges) + self.transpose = P.Transpose() + + self.softmax_with_mask = SoftmaxWithMask() + + def construct(self, query, key, value, mask=None, cutoff=None): + r"""Compute multi-head attention. + + Args: + query (Mindspore.Tensor [..., 1, F] or [..., X, F]): + key (Mindspore.Tensor [..., X, F]): + value (Mindspore.Tensor [..., X, F]): + mask (Mindspore.Tensor [..., X]): + cutoff (Mindspore.Tensor [..., X]): + + Returns: + Mindspore.Tensor [..., F]: multi-head attention output. 
+ + """ + if self.n_heads > 1: + q_reshape = query.shape[:-1] + self.reshape_tail + k_reshape = key.shape[:-1] + self.reshape_tail + v_reshape = value.shape[:-1] + self.reshape_tail + + # [..., 1, h, f] or [..., X, h, f] + q = F.reshape(query, q_reshape) + # [..., h, 1, f] or [..., h, X, f] + q = self.transpose(q, self.trans_shape) + + # [..., X, h, f] + k = F.reshape(key, k_reshape) + # [..., h, X, f] + k = self.transpose(k, self.trans_shape) + + # [..., X, h, f] + v = F.reshape(value, v_reshape) + # [..., h, X, f] + v = self.transpose(v, self.trans_shape) + + # [..., h, 1, f] x [..., h, X, f]^T = [..., h, 1, X] + # or + # [..., h, X, f] x [..., h, X, f]^T = [..., h, X, X] + attention_scores = self.bmmt(q, k) + # ([..., h, 1, X] or [..., h, X, X]) / sqrt(f) + attention_scores = self.mul(attention_scores, self.scores_mul) + + if mask is None: + # [..., h, 1, X] or [..., h, X, X] + attention_probs = self.softmax(attention_scores) + else: + # [..., X] -> [..., 1, 1, X] + exmask = F.expand_dims(F.expand_dims(mask, -2), -2) + # [..., 1, 1, X] -> ([..., h, 1, X] or [..., h, X, X]) + mhmask = (exmask * F.ones_like(attention_scores)) > 0 + # [..., h, 1, X] or [..., h, X, X] + attention_probs = self.softmax_with_mask( + attention_scores, mhmask) + + if cutoff is not None: + # [..., X] -> [..., 1, 1, X] + excut = F.expand_dims(F.expand_dims(cutoff, -2), -2) + attention_probs = self.mul(attention_probs, excut) + + context = self.bmm(attention_probs, v) + # [..., 1, h, f] or [..., X, h, f] + context = self.transpose(context, self.trans_shape) + # [..., 1, F] or [..., X, F] + context = F.reshape(context, query.shape) + + else: + attention_scores = self.bmmt(query, key) * self.scores_mul + + if mask is None: + # [..., 1, X] or [..., X, X] + attention_probs = self.softmax(attention_scores) + else: + # [..., X] -> [..., 1, X] + mask = F.expand_dims(mask, -2) + # [..., 1, X] + attention_probs = self.softmax_with_mask( + attention_scores, mask) + + if cutoff is not None: + # [..., 1, X] * [..., 1, X] + attention_probs = attention_probs * \ + F.expand_dims(cutoff, -2) + + # [..., 1, X] x [..., X, F] = [..., 1, F] + # or + # [..., X, X] x [..., X, F] = [..., X, F] + context = self.bmm(attention_probs, value) + + # [..., 1, F] or [..., X, F] + return self.output(context) + + +class FeedForward(nn.Cell): + def __init__(self, dim, activation, n_hidden=1): + super().__init__() + + self.norm = nn.LayerNorm((dim,), -1, -1) + self.residual = Residual(dim, activation=activation, n_hidden=n_hidden) + + def construct(self, x): + nx = self.norm(x) + return self.residual(nx) + + +class Pondering(nn.Cell): + """Pondering""" + def __init__(self, n_in, n_hidden=0, bias_const=1.): + super().__init__() + + if n_hidden == 0: + self.dense = nn.Dense( + n_in, + 1, + has_bias=True, + weight_init='xavier_uniform', + bias_init=Constant(bias_const), + activation='sigmoid', + ) + elif n_hidden > 0: + nets = [] + for _ in range(n_hidden): + nets.append(nn.Dense(n_in, n_in, weight_init='xavier_uniform', activation='relu')) + nets.append(nn.Dense(n_in, 1, bias_init=Constant(bias_const), activation='sigmoid')) + self.dense = nn.SequentialCell(nets) + else: + raise ValueError("n_hidden cannot be negative!") + + self.squeeze = P.Squeeze(-1) + + def construct(self, x): + y = self.dense(x) + return self.squeeze(y) + +# Modified from: +# https://github.com/andreamad8/Universal-Transformer-Pytorch/blob/master/models/UTransformer.py + + +class ACTWeight(nn.Cell): + """ACTWeight""" + def __init__(self, threshold=0.9): + super().__init__() + 
+        self.threshold = threshold
+
+        self.zeros_like = P.ZerosLike()
+        self.ones_like = P.OnesLike()
+
+    def construct(self, prob, halting_prob):
+        """construct"""
+
+        # Mask for inputs which have not halted at the last cycle
+        running = F.cast(halting_prob < 1.0, ms.float32)
+
+        # Add the halting probability for this step to the halting
+        # probabilities for those inputs which haven't halted yet
+        add_prob = prob * running
+        new_prob = halting_prob + add_prob
+        mask_run = F.cast(new_prob <= self.threshold, ms.float32)
+        mask_halt = F.cast(new_prob > self.threshold, ms.float32)
+
+        # Mask of inputs which haven't halted, and didn't halt this step
+        still_running = mask_run * running
+        running_prob = halting_prob + prob * still_running
+
+        # Mask of inputs which halted at this step
+        new_halted = mask_halt * running
+
+        # Compute remainders for the inputs which halted at this step
+        remainders = new_halted * (1.0 - running_prob)
+
+        # Add the remainders to those inputs which halted at this step
+        # halting_prob = new_prob + remainders
+        dp = add_prob + remainders
+
+        # Increment n_updates for all inputs which are still running
+        # n_updates = n_updates + running
+        dn = running
+
+        # Compute the weight to be applied to the new state and output:
+        # 0 when the input has already halted,
+        # prob when the input hasn't halted yet,
+        # the remainder when it halted this step
+        update_weights = prob * still_running + new_halted * remainders
+        w = F.expand_dims(update_weights, -1)
+
+        return w, dp, dn
+
+
+class Num2Mask(nn.Cell):
+    """Num2Mask"""
+    def __init__(self, dim):
+        super().__init__()
+        self.range = nn.Range(dim)
+        ones = P.Ones()
+        self.ones = ones((dim), ms.int32)
+
+    def construct(self, num):
+        nmax = num * self.ones
+        idx = F.ones_like(num) * self.range()
+        return idx < nmax
+
+
+class Number2FullConnectNeighbors(nn.Cell):
+    """Number2FullConnectNeighbors"""
+    def __init__(self, tot_atoms):
+        super().__init__()
+        # tot_atoms: A
+        # tot_neigh: N = A - 1
+        tot_neigh = tot_atoms - 1
+        arange = nn.Range(tot_atoms)
+        nrange = nn.Range(tot_neigh)
+
+        self.ones = P.Ones()
+        self.aones = self.ones((tot_atoms), ms.int32)
+        self.nones = self.ones((tot_neigh), ms.int32)
+
+        # neighbors for no connection (A*N)
+        # [[0,0,...,0],
+        #  [1,1,...,1],
+        #  ...........,
+        #  [N,N,...,N]]
+        self.nnc = F.expand_dims(arange(), -1) * self.nones
+        # copy of the index range (A*N)
+        # [[0,1,...,N-1],
+        #  [0,1,...,N-1],
+        #  ...........,
+        #  [0,1,...,N-1]]
+        crange = self.ones((tot_atoms, 1), ms.int32) * nrange()
+        # neighbors for full connection (A*N)
+        # [[1,2,3,...,N],
+        #  [0,2,3,...,N],
+        #  [0,1,3,....N],
+        #  .............,
+        #  [0,1,2,...,N-1]]
+        self.nfc = crange + F.cast(self.nnc <= crange, ms.int32)
+
+        crange1 = crange + 1
+        # the matrix for index range (A*N)
+        # [[1,2,3,...,N],
+        #  [1,2,3,...,N],
+        #  [2,2,3,....N],
+        #  [3,3,3,....N],
+        #  .............,
+        #  [N,N,N,...,N]]
+        self.mat_idx = F.select(crange1 > self.nnc, crange1, self.nnc)
+
+    def get_full_neighbors(self):
+        """get_full_neighbors"""
+        return F.expand_dims(self.nfc, 0)
+
+    def construct(self, num_atoms):
+        """construct"""
+        # broadcast atom numbers to [B*A*N]
+        # a_i: number of atoms in each molecule
+        exnum = num_atoms * self.aones
+        exnum = F.expand_dims(exnum, -1) * self.nones
+
+        # [B,1,1]
+        exones = self.ones((num_atoms.shape[0], 1, 1), ms.int32)
+        # broadcast to [B*A*N]: [B,1,1] * [1,A,N]
+        exnfc = exones * F.expand_dims(self.nfc, 0)
+        exnnc = exones * F.expand_dims(self.nnc, 0)
+        exmat = exones * F.expand_dims(self.mat_idx, 0)
+
+        mask = exmat < exnum
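+
+        # where the mask is True, keep the full-connection neighbor index;
+        # elsewhere fall back to the no-connection (self) index, which acts
+        # as padding and is screened out by the returned mask
+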
neighbors = F.select(mask, exnfc, exnnc) + + return neighbors, mask + + +class Types2FullConnectNeighbors(nn.Cell): + """Types2FullConnectNeighbors""" + def __init__(self, tot_atoms): + super().__init__() + # tot_atoms: A + # tot_neigh: N = A - 1 + tot_neigh = tot_atoms - 1 + arange = nn.Range(tot_atoms) + nrange = nn.Range(tot_neigh) + + self.ones = P.Ones() + self.aones = self.ones((tot_atoms), ms.int32) + self.nones = self.ones((tot_neigh), ms.int32) + self.eaones = F.expand_dims(self.aones, -1) + + # neighbors for no connection (A*N) + # [[0,0,...,0], + # [1,1,...,1], + # ..........., + # [N,N,...,N]] + self.nnc = F.expand_dims(arange(), -1) * self.nones + + # copy of the index range (A*N) + # [[0,1,...,N-1], + # [0,1,...,N-1], + # ..........., + # [0,1,...,N-1]] + exrange = self.ones((tot_atoms, 1), ms.int32) * nrange() + + # neighbors for full connection (A*N) + # [[1,2,3,...,N], + # [0,2,3,...,N], + # [0,1,3,....N], + # ............., + # [0,1,2,...,N-1]] + self.nfc = exrange + F.cast(self.nnc <= exrange, ms.int32) + + self.ar0 = nn.Range(0, tot_neigh)() + self.ar1 = nn.Range(1, tot_atoms)() + + def get_full_neighbors(self): + """get_full_neighbors""" + return F.expand_dims(self.nfc, 0) + + def construct(self, atom_types): + """construct""" + # [B,1,1] + exones = self.ones((atom_types.shape[0], 1, 1), ms.int32) + # broadcast to [B*A*N]: [B,1,1] * [1,A,N] + exnfc = exones * F.expand_dims(self.nfc, 0) + exnnc = exones * F.expand_dims(self.nnc, 0) + + tmask = F.select( + atom_types > 0, + F.ones_like(atom_types), + F.ones_like(atom_types) * -1) + tmask = F.cast(tmask, ms.float32) + extmask = F.expand_dims(tmask, -1) * self.nones + + mask0 = F.gather(tmask, self.ar0, -1) + mask0 = F.expand_dims(mask0, -2) * self.eaones + mask1 = F.gather(tmask, self.ar1, -1) + mask1 = F.expand_dims(mask1, -2) * self.eaones + + mtmp = F.select(exnfc > exnnc, mask1, mask0) + mask = F.select(extmask > 0, mtmp, F.ones_like(mtmp) * -1) + mask = mask > 0 + + idx = F.select(mask, exnfc, exnnc) + + return idx, mask diff --git a/MindSPONGE/mindsponge/md/cybertron/blocks.py b/MindSPONGE/mindsponge/md/cybertron/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf6716f0ae78190f96a94cefdea508a5335f8c5 --- /dev/null +++ b/MindSPONGE/mindsponge/md/cybertron/blocks.py @@ -0,0 +1,193 @@ +# ============================================================================ +# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University +# +# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang, +# Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao +# +# This code is a part of Cybertron-Code package. +# +# The Cybertron-Code is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""cybertron blocks"""
+
+from mindspore import nn
+from cybertron.activations import get_activation
+
+__all__ = [
+    "Dense",
+    "MLP",
+    "Residual",
+    "PreActDense",
+    "PreActResidual",
+]
+
+
+class Dense(nn.Dense):
+    """Dense"""
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 weight_init='xavier_uniform',
+                 bias_init='zero',
+                 has_bias=True,
+                 activation=None,
+                 ):
+        super().__init__(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            weight_init=weight_init,
+            bias_init=bias_init,
+            has_bias=has_bias,
+            activation=get_activation(activation),
+        )
+
+
+class MLP(nn.Cell):
+    """Multi-layer fully connected perceptron neural network.
+
+    Args:
+        n_in (int): number of input dimensions.
+        n_out (int): number of output dimensions.
+        layer_dims (list of int): dimensions of the hidden layers.
+            If None or empty, a single Dense layer mapping n_in to n_out
+            is used instead of a multi-layer network.
+        activation (callable, optional): activation function. All hidden
+            layers use the same activation function; the output layer
+            applies no activation unless use_last_activation is True.
+
+    """
+
+    def __init__(self, n_in, n_out, layer_dims=None, activation=None, weight_init='xavier_uniform', bias_init='zero',
+                 use_last_activation=False,):
+        super().__init__()
+
+        # get list of number of dimensions in input, hidden & output layers
+        if layer_dims is None or not layer_dims:
+            self.mlp = nn.Dense(n_in, n_out, activation=activation)
+        else:
+            # assign a Dense layer (with activation function) to each hidden
+            # layer
+            nets = []
+            indim = n_in
+            for ldim in layer_dims:
+                nets.append(
+                    nn.Dense(
+                        in_channels=indim,
+                        out_channels=ldim,
+                        weight_init=weight_init,
+                        bias_init=bias_init,
+                        has_bias=True,
+                        activation=get_activation(activation),
+                    )
+                )
+                indim = ldim
+
+            # assign a Dense layer to the output layer
+            if use_last_activation and activation is not None:
+                nets.append(
+                    nn.Dense(
+                        in_channels=indim,
+                        out_channels=n_out,
+                        weight_init=weight_init,
+                        bias_init=bias_init,
+                        has_bias=True,
+                        activation=get_activation(activation),
+                    )
+                )
+            else:
+                nets.append(
+                    nn.Dense(
+                        in_channels=indim,
+                        out_channels=n_out,
+                        weight_init=weight_init,
+                        bias_init=bias_init,
+                        has_bias=True,
+                        activation=None)
+                )
+            # put all layers together to make the network
+            self.mlp = nn.SequentialCell(nets)
+
+    def construct(self, x):
+        """Compute the network output.
+
+        Args:
+            x (mindspore.Tensor): network input.
+
+        Returns:
+            mindspore.Tensor: network output.
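+
+        Example (a minimal sketch; the layer widths are assumed for
+        illustration only):
+
+            >>> import numpy as np
+            >>> import mindspore as ms
+            >>> net = MLP(128, 1, [64, 32], activation='swish')
+            >>> x = ms.Tensor(np.random.rand(8, 128), ms.float32)
+            >>> net(x).shape
+            (8, 1)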
+
+        """
+
+        y = self.mlp(x)
+
+        return y
+
+
+class Residual(nn.Cell):
+    """Residual"""
+    def __init__(self, dim, activation, n_hidden=1):
+        super().__init__()
+
+        if n_hidden > 0:
+            hidden_layers = [dim for _ in range(n_hidden)]
+            self.nonlinear = MLP(
+                dim, dim, hidden_layers, activation=activation)
+        else:
+            self.nonlinear = Dense(dim, dim, activation=activation)
+
+    def construct(self, x):
+        return x + self.nonlinear(x)
+
+
+class PreActDense(nn.Cell):
+    """PreActDense"""
+    def __init__(self, dim_in, dim_out, activation):
+        super().__init__()
+
+        self.activation = get_activation(activation)
+        self.dense = Dense(dim_in, dim_out, activation=None)
+
+    def construct(self, x):
+        x = self.activation(x)
+        return self.dense(x)
+
+
+class PreActResidual(nn.Cell):
+    """PreActResidual"""
+    def __init__(self, dim, activation):
+        super().__init__()
+
+        self.preact_dense1 = PreActDense(dim, dim, activation)
+        self.preact_dense2 = PreActDense(dim, dim, activation)
+
+    def construct(self, x):
+        x1 = self.preact_dense1(x)
+        # apply the second pre-activation layer to the intermediate result
+        x2 = self.preact_dense2(x1)
+        return x + x2
+
+
+class SeqPreActResidual(nn.Cell):
+    """SeqPreActResidual"""
+    def __init__(self, dim, activation, n_res):
+        super().__init__()
+
+        self.sequential = nn.SequentialCell(
+            [PreActResidual(dim, activation) for _ in range(n_res)]
+        )
+
+    def construct(self, x):
+        return self.sequential(x)
diff --git a/MindSPONGE/mindsponge/md/cybertron/cutoff.py b/MindSPONGE/mindsponge/md/cybertron/cutoff.py
new file mode 100644
index 0000000000000000000000000000000000000000..2207f0ff9a148ccfa1369d0ab35939f1153f100f
--- /dev/null
+++ b/MindSPONGE/mindsponge/md/cybertron/cutoff.py
@@ -0,0 +1,458 @@
+# ============================================================================
+# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University
+#
+# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang,
+#         Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao
+#
+# This code is a part of Cybertron-Code package.
+#
+# The Cybertron-Code is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""cybertron cutoff"""
+
+import numpy as np
+import mindspore as ms
+from mindspore import nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+
+from cybertron.units import units
+
+__all__ = [
+    "CosineCutoff",
+    "MollifierCutoff",
+    "HardCutoff",
+    "SmoothCutoff",
+    "GaussianCutoff",
+    "get_cutoff",
+]
+
+_CUTOFF_ALIAS = dict()
+
+
+def _cutoff_register(*aliases):
+    """Return the alias register."""
+    def alias_reg(cls):
+        name = cls.__name__
+        name = name.lower()
+        if name not in _CUTOFF_ALIAS:
+            _CUTOFF_ALIAS[name] = cls
+
+        for alias in aliases:
+            if alias not in _CUTOFF_ALIAS:
+                _CUTOFF_ALIAS[alias] = cls
+
+        return cls
+
+    return alias_reg
+
+
+class Cutoff(nn.Cell):
+    """Cutoff"""
+    def __init__(self,
+                 r_max=units.length(1, 'nm'),
+                 r_min=0,
+                 hyperparam='default',
+                 return_mask=False,
+                 reverse=False
+                 ):
+        super().__init__()
+        self.name = 'cutoff'
+        self.hyperparam = hyperparam
+        self.r_min = r_min
+        self.cutoff = r_max
+        self.return_mask = return_mask
+        self.reverse = reverse
+
+
+@_cutoff_register('cosine')
+class CosineCutoff(Cutoff):
+    r"""Class of Behler cosine cutoff.
+
+    .. math::
+        f(r) = \begin{cases}
+        0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
+        & r < r_\text{cutoff} \\
+        0 & r \geqslant r_\text{cutoff} \\
+        \end{cases}
+
+    Args:
+        cutoff (float, optional): cutoff radius.
+
+    """
+
+    def __init__(self,
+                 r_max=units.length(1, 'nm'),
+                 r_min=0,
+                 hyperparam='default',
+                 return_mask=False,
+                 reverse=False
+                 ):
+        super().__init__(
+            r_max=r_max,
+            r_min=r_min,
+            hyperparam=None,
+            return_mask=return_mask,
+            reverse=reverse,
+        )
+
+        self.name = 'cosine cutoff'
+        self.pi = Tensor(np.pi, ms.float32)
+        self.cos = P.Cos()
+        self.logical_and = P.LogicalAnd()
+
+    def construct(self, distances, neighbor_mask=None):
+        """Compute cutoff.
+
+        Args:
+            distances (mindspore.Tensor): values of interatomic distances.
+
+        Returns:
+            mindspore.Tensor: values of cutoff function.
+
+        """
+        # Compute values of cutoff function
+        cuts = 0.5 * (self.cos(distances * self.pi / self.cutoff) + 1.0)
+        if self.reverse:
+            cuts = 1.0 - cuts
+            ones = F.ones_like(cuts)
+            # keep the reversed value inside the cutoff radius, 1 beyond it
+            cuts = F.select(distances < self.cutoff, cuts, ones)
+            if neighbor_mask is None:
+                mask = distances >= 0
+            else:
+                mask = neighbor_mask
+        else:
+            mask = distances < self.cutoff
+            if neighbor_mask is not None:
+                mask = self.logical_and(mask, neighbor_mask)
+
+        # Remove contributions beyond the cutoff radius
+        cutoffs = cuts * mask
+
+        if self.return_mask:
+            return cutoffs, mask
+        return cutoffs
+
+
+@_cutoff_register('mollifier')
+class MollifierCutoff(Cutoff):
+    r"""Class for mollifier cutoff scaled to have a value of 1 at :math:`r=0`.
+
+    .. math::
+        f(r) = \begin{cases}
+        \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right)
+        & r < r_\text{cutoff} \\
+        0 & r \geqslant r_\text{cutoff} \\
+        \end{cases}
+
+    Args:
+        cutoff (float, optional): Cutoff radius.
+        eps (float, optional): offset added to distances for numerical stability.
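+
+        Example (an illustrative sketch; the distances are assumed and given
+        in the default length unit):
+
+            >>> import mindspore as ms
+            >>> cutoff = MollifierCutoff(r_max=1.0, return_mask=True)
+            >>> dis = ms.Tensor([[0.3, 0.6, 1.2]], ms.float32)
+            >>> cuts, mask = cutoff(dis)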
+
+    """
+
+    def __init__(self,
+                 r_max=units.length(1, 'nm'),
+                 r_min=0,
+                 hyperparam='default',
+                 return_mask=False,
+                 reverse=False
+                 ):
+        super().__init__(
+            r_min=r_min,
+            r_max=r_max,
+            hyperparam=hyperparam,
+            return_mask=return_mask,
+            reverse=reverse,
+        )
+
+        self.name = "Mollifier cutoff"
+
+        if hyperparam == 'default':
+            self.eps = units.length(1.0e-8, 'nm')
+        else:
+            self.eps = hyperparam
+
+        self.exp = P.Exp()
+        self.logical_and = P.LogicalAnd()
+
+    def construct(self, distances, neighbor_mask=None):
+        """Compute cutoff.
+
+        Args:
+            distances (mindspore.Tensor): values of interatomic distances.
+
+        Returns:
+            mindspore.Tensor: values of cutoff function.
+
+        """
+
+        exponent = 1.0 - 1.0 / (1.0 - F.square(distances / self.cutoff))
+        cutoffs = self.exp(exponent)
+
+        if self.reverse:
+            cutoffs = 1. - cutoffs
+            ones = F.ones_like(cutoffs)
+            cutoffs = F.select(distances < self.cutoff, cutoffs, ones)
+            if neighbor_mask is None:
+                mask = (distances + self.eps) >= 0
+            else:
+                mask = neighbor_mask
+        else:
+            mask = (distances + self.eps) < self.cutoff
+            if neighbor_mask is not None:
+                mask = self.logical_and(mask, neighbor_mask)
+
+        cutoffs = cutoffs * mask
+
+        if self.return_mask:
+            return cutoffs, mask
+        return cutoffs
+
+
+@_cutoff_register('hard')
+class HardCutoff(Cutoff):
+    r"""Class of hard cutoff.
+
+    .. math::
+        f(r) = \begin{cases}
+        1 & r \leqslant r_\text{cutoff} \\
+        0 & r > r_\text{cutoff} \\
+        \end{cases}
+
+    Args:
+        cutoff (float): cutoff radius.
+
+    """
+
+    def __init__(self,
+                 r_max=units.length(1, 'nm'),
+                 r_min=0,
+                 hyperparam='default',
+                 return_mask=False,
+                 reverse=False
+                 ):
+        super().__init__(
+            r_min=r_min,
+            r_max=r_max,
+            hyperparam=None,
+            return_mask=return_mask,
+            reverse=reverse,
+        )
+
+        self.name = "Hard cutoff"
+        self.logical_and = P.LogicalAnd()
+
+    def construct(self, distances, neighbor_mask=None):
+        """Compute cutoff.
+
+        Args:
+            distances (mindspore.Tensor): values of interatomic distances.
+
+        Returns:
+            mindspore.Tensor: values of cutoff function.
+
+        """
+
+        if self.reverse:
+            mask = distances >= self.cutoff
+        else:
+            mask = distances < self.cutoff
+
+        if neighbor_mask is not None:
+            mask = self.logical_and(mask, neighbor_mask)
+
+        if self.return_mask:
+            return F.cast(mask, distances.dtype), mask
+        return F.cast(mask, distances.dtype)
+
+
+@_cutoff_register('smooth')
+class SmoothCutoff(Cutoff):
+    r"""Class of smooth cutoff by Ebert, D. S. et al:
+        [ref] Ebert, D. S.; Musgrave, F. K.; Peachey, D.; Perlin, K.; Worley, S.
+        Texturing & Modeling: A Procedural Approach; Morgan Kaufmann: 2003
+
+    .. math::
+        with s = (r - r_min) / (r_max - r_min):
+
+        r_min < r < r_max:
+        f(r) = 1.0 - 6 * s^5 + 15 * s^4 - 10 * s^3
+        r >= r_max: f(r) = 0
+        r <= r_min: f(r) = 1
+
+        reverse:
+        r_min < r < r_max:
+        f(r) = 6 * s^5 - 15 * s^4 + 10 * s^3
+        r >= r_max: f(r) = 1
+        r <= r_min: f(r) = 0
+
+    Args:
+        r_max (float, optional): the maximum distance (cutoff radius).
+        r_min (float, optional): the minimum distance.
+
+    """
+
+    def __init__(self,
+                 r_max=units.length(1, 'nm'),
+                 r_min=0,
+                 hyperparam='default',
+                 return_mask=False,
+                 reverse=False
+                 ):
+        super().__init__(
+            r_min=r_min,
+            r_max=r_max,
+            hyperparam=None,
+            return_mask=return_mask,
+            reverse=reverse,
+        )
+
+        if self.r_min >= self.cutoff:
+            raise ValueError(
+                'r_min must be smaller than r_max in SmoothCutoff')
+
+        self.dis_range = self.cutoff - self.r_min
+
+        self.pow = P.Pow()
+        self.logical_and = P.LogicalAnd()
+
+    def construct(self, distance, neighbor_mask=None):
+        """Compute cutoff.
+
+        Args:
+            distance (mindspore.Tensor or float): values of interatomic distances.
+
+        Returns:
+            mindspore.Tensor or float: values of cutoff function.
+
+        """
+        dd = distance - self.r_min
+        dd = dd / self.dis_range
+        cuts = - 6. * self.pow(dd, 5) \
+               + 15. * self.pow(dd, 4) \
+               - 10. * self.pow(dd, 3)
+
+        if self.reverse:
+            cutoffs = -cuts
+            mask_upper = distance < self.cutoff
+            mask_lower = distance > self.r_min
+        else:
+            cutoffs = 1 + cuts
+            mask_upper = distance > self.r_min
+            mask_lower = distance < self.cutoff
+
+        if neighbor_mask is not None:
+            mask_lower = self.logical_and(mask_lower, neighbor_mask)
+
+        zeros = F.zeros_like(distance)
+        ones = F.ones_like(distance)
+
+        cutoffs = F.select(mask_upper, cutoffs, ones)
+        cutoffs = F.select(mask_lower, cutoffs, zeros)
+
+        if self.return_mask:
+            return cutoffs, mask_lower
+        return cutoffs
+
+
+@_cutoff_register('gaussian')
+class GaussianCutoff(Cutoff):
+    r"""Class of Gaussian-type cutoff.
+
+    .. math::
+        f(r) = \begin{cases}
+        1 - \exp\left(-\frac{(r - r_\text{cutoff})^2}{2 \sigma^2}\right)
+        & r < r_\text{cutoff} \\
+        0 & r \geqslant r_\text{cutoff} \\
+        \end{cases}
+
+    Args:
+        cutoff (float): cutoff radius.
+        hyperparam (float, optional): the width :math:`\sigma` of the Gaussian.
+
+    """
+
+    def __init__(self,
+                 r_max=units.length(1, 'nm'),
+                 r_min=0,
+                 hyperparam='default',
+                 return_mask=False,
+                 reverse=False
+                 ):
+        super().__init__(
+            r_min=r_min,
+            r_max=r_max,
+            hyperparam=hyperparam,
+            return_mask=return_mask,
+            reverse=reverse,
+        )
+
+        if hyperparam == 'default':
+            self.sigma = units.length(1, 'nm')
+        else:
+            self.sigma = hyperparam
+
+        self.sigma2 = self.sigma * self.sigma
+
+        self.exp = P.Exp()
+        self.logical_and = P.LogicalAnd()
+
+    def construct(self, distance, neighbor_mask=None):
+        """construct"""
+        dd = distance - self.cutoff
+        dd2 = dd * dd
+
+        gauss = self.exp(-0.5 * dd2 / self.sigma2)
+
+        if self.reverse:
+            cuts = gauss
+            ones = F.ones_like(cuts)
+            cuts = F.select(distance < self.cutoff, cuts, ones)
+
+            if neighbor_mask is None:
+                mask = distance >= 0
+            else:
+                mask = neighbor_mask
+        else:
+            cuts = 1. - gauss
+            mask = distance < self.cutoff
+            if neighbor_mask is not None:
+                mask = self.logical_and(mask, neighbor_mask)
+
+        cuts = cuts * mask
+
+        if self.return_mask:
+            return cuts, mask
+        return cuts
+
+
+def get_cutoff(obj, r_max=units.length(1, 'nm'), r_min=0, hyperparam='default', return_mask=False, reverse=False):
+    """get cutoff"""
+    if obj is None or isinstance(obj, Cutoff):
+        return obj
+    if isinstance(obj, str):
+        if obj.lower() not in _CUTOFF_ALIAS.keys():
+            raise ValueError(
+                "The class corresponding to '{}' was not found.".format(obj))
+        return _CUTOFF_ALIAS[obj.lower()](
+            r_min=r_min,
+            r_max=r_max,
+            hyperparam=hyperparam,
+            return_mask=return_mask,
+            reverse=reverse,
+        )
+    raise TypeError("Unsupported Cutoff type '{}'.".format(type(obj)))
diff --git a/MindSPONGE/mindsponge/md/cybertron/cybertron.py b/MindSPONGE/mindsponge/md/cybertron/cybertron.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b9068f4515581eccf1673f67cabd2085389976
--- /dev/null
+++ b/MindSPONGE/mindsponge/md/cybertron/cybertron.py
@@ -0,0 +1,477 @@
+# ============================================================================
+# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University
+#
+# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang,
+#         Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao
+#
+# This code is a part of Cybertron-Code package.
+#
+# The Cybertron-Code is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""cybertron"""
+
+import mindspore as ms
+from mindspore import nn
+from mindspore import Tensor
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+from cybertron.units import units
+from cybertron.base import Types2FullConnectNeighbors
+from cybertron.readouts import Readout, LongeRangeReadout
+from cybertron.readouts import AtomwiseReadout, GraphReadout
+from cybertron.neighbors import Distances
+
+
+class Cybertron(nn.Cell):
+    """Cybertron: an architecture of deep molecular models for molecular modeling.
+ + Args: + model (nn.Cell): Deep molecular model + dim_output (int): Output dimension of the predictions + unit_dis (str): Unit of input distance + unit_energy (str): Unit of output energy + readout (readouts.Readout): Readout function + + """ + + def __init__(self, model, dim_output=1, unit_dis='nm', unit_energy=None, readout='atomwise', max_atoms_number=0, + atom_types=None, bond_types=None, pbcbox=None, full_connect=False, cut_shape=False,): + super().__init__() + + self.model = model + self.dim_output = dim_output + self.cut_shape = cut_shape + + self.unit_dis = unit_dis + self.unit_energy = unit_energy + + self.dis_scale = units.length_convert_from(unit_dis) + activation = self.model.activation + + self.molsum = P.ReduceSum(keep_dims=True) + + self.atom_mask = None + self.atom_types = None + if atom_types is None: + self.fixed_atoms = False + self.num_atoms = 0 + else: + self.fixed_atoms = True + self.model.set_fixed_atoms(True) + + if len(atom_types.shape) == 1: + self.num_atoms = len(atom_types) + elif len(atom_types.shape) == 2: + self.num_atoms = len(atom_types[0]) + + if self.num_atoms <= 0: + raise ValueError( + "The 'num_atoms' cannot be 0 " + + "'atom_types' is not 'None' in MolCalculator!") + + if not isinstance(atom_types, Tensor): + atom_types = Tensor(atom_types, ms.int32) + + self.atom_types = atom_types + self.atom_mask = F.expand_dims(atom_types, -1) > 0 + if self.atom_mask.all(): + self.atom_mask = None + + atoms_number = F.cast(atom_types > 0, ms.float32) + self.atoms_number = self.molsum(atoms_number, -1) + + self.pbcbox = None + self.use_fixed_box = False + if pbcbox is not None: + if isinstance(pbcbox, (list, tuple)): + pbcbox = Tensor(pbcbox, ms.float32) + if not isinstance(pbcbox, Tensor): + raise TypeError( + "Unsupported pbcbox type '{}'.".format( + type(pbcbox))) + if len(pbcbox.shape) == 1: + pbcbox = F.expand_dims(pbcbox, 0) + if len(pbcbox.shape) != 2: + raise ValueError( + "The length of shape of pbcbox must be 1 or 2") + if pbcbox.shape[-1] != 3: + raise ValueError("The last dimension of pbcbox must be 3") + if pbcbox.shape[0] != 1: + raise ValueError("The first dimension of pbcbox must be 1") + self.pbcbox = pbcbox + self.use_fixed_box = True + + self.use_bonds = self.model.use_bonds + self.fixed_bonds = False + self.bonds = None + if bond_types is not None: + self.bonds = bond_types + self.bond_mask = (bond_types > 0) + self.fixed_bonds = True + + self.cutoff = self.model.cutoff + + self.use_distances = self.model.use_distances + + self.full_connect = full_connect + + if self.fixed_bonds and (not self.use_distances): + raise ValueError( + '"fixed_bonds" cannot be used without using distances') + + self.neighbors = None + self.mask = None + self.fc_neighbors = None + if self.full_connect: + if self.fixed_atoms: + self.fc_neighbors = Types2FullConnectNeighbors(self.num_atoms) + self.neighbors = self.fc_neighbors.get_full_neighbors() + else: + if max_atoms_number <= 0: + raise ValueError( + "The 'max_atoms_num' cannot be 0 " + + "when the 'full_connect' flag is 'True' and " + + "'atom_types' is 'None' in MolCalculator!") + self.fc_neighbors = Types2FullConnectNeighbors( + max_atoms_number) + self.max_atoms_number = max_atoms_number + + if self.fixed_atoms and self.full_connect: + fixed_neigh = True + self.distances = Distances(True, long_dis=self.cutoff * 10) + self.model.set_fixed_neighbors(True) + else: + fixed_neigh = False + self.distances = Distances(False, long_dis=self.cutoff * 10) + self.fixed_neigh = fixed_neigh + + self.multi_readouts = False + 
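+        # bookkeeping for the readout function(s); a tuple or list of
+        # readouts is handled by concatenating their outputs along the
+        # last axis in construct()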
+        self.num_readout = 1
+
+        dim_feature = self.model.dim_feature
+        n_interactions = self.model.n_interactions
+
+        if isinstance(readout, (tuple, list)):
+            self.num_readout = len(readout)
+            if self.num_readout == 1:
+                readout = readout[0]
+            else:
+                self.multi_readouts = True
+
+        if self.multi_readouts:
+            readouts = []
+            for i in range(self.num_readout):
+                readouts.append(self._get_readout(readout[i],
+                                                  n_in=dim_feature,
+                                                  n_out=dim_output,
+                                                  activation=activation,
+                                                  unit_energy=unit_energy,
+                                                  ))
+            self.readout = nn.CellList(readouts)
+        else:
+            self.readout = self._get_readout(readout,
+                                             n_in=dim_feature,
+                                             n_out=dim_output,
+                                             activation=activation,
+                                             unit_energy=unit_energy,
+                                             )
+
+        self.output_scale = 1
+        self.calc_far = False
+        read_all_interactions = False
+        self.dim_output = 0
+        if self.multi_readouts:
+            self.output_scale = []
+            for i in range(self.num_readout):
+                self.dim_output += self.readout[i].total_out
+                if unit_energy is not None and self.readout[i].output_is_energy:
+                    unit_energy = units.check_energy_unit(unit_energy)
+                    self.output_scale.append(
+                        units.energy_convert_to(unit_energy))
+                else:
+                    self.output_scale.append(1)
+
+                if isinstance(self.readout[i], LongeRangeReadout):
+                    self.calc_far = True
+                    self.readout[i].set_fixed_neighbors(fixed_neigh)
+                if self.readout[i].read_all_interactions:
+                    read_all_interactions = True
+                    if self.readout[i].interaction_decoders is not None and\
+                            self.readout[i].n_interactions != n_interactions:
+                        raise ValueError(
+                            'The n_interactions in model readouts are not equal')
+                if self.readout[i].n_in != dim_feature:
+                    raise ValueError(
+                        'n_in in readouts is not equal to dim_feature')
+        else:
+            self.dim_output = self.readout.total_out
+
+            if unit_energy is not None and self.readout.output_is_energy:
+                unit_energy = units.check_energy_unit(unit_energy)
+                self.output_scale = units.energy_convert_to(unit_energy)
+            else:
+                self.output_scale = 1
+
+            if isinstance(self.readout, LongeRangeReadout):
+                self.calc_far = True
+                self.readout.set_fixed_neighbors(fixed_neigh)
+
+            if self.readout.read_all_interactions:
+                read_all_interactions = True
+                if self.readout.interaction_decoders is not None and self.readout.n_interactions != n_interactions:
+                    raise ValueError(
+                        'The n_interactions in model readouts are not equal')
+
+            if self.readout.n_in != dim_feature:
+                raise ValueError(
+                    'n_in in readouts is not equal to dim_feature')
+
+        self.unit_energy = unit_energy
+
+        self.model.read_all_interactions = read_all_interactions
+
+        self.ones = P.Ones()
+        self.reduceany = P.ReduceAny(keep_dims=True)
+        self.reducesum = P.ReduceSum(keep_dims=False)
+        self.reducemax = P.ReduceMax()
+        self.reducemean = P.ReduceMean(keep_dims=False)
+        self.concat = P.Concat(-1)
+
+    def _get_readout(self, readout, n_in, n_out, activation, unit_energy,):
+        """get readout"""
+        if readout is None or isinstance(readout, Readout):
+            return readout
+        if isinstance(readout, str):
+            if readout.lower() == 'atom' or readout.lower() == 'atomwise':
+                readout = AtomwiseReadout
+            elif readout.lower() == 'graph' or readout.lower() == 'set2set':
+                readout = GraphReadout
+            else:
+                raise ValueError("Unsupported Readout type " + readout.lower())
+
+            return readout(
+                n_in=n_in,
+                n_out=n_out,
+                activation=activation,
+                unit_energy=unit_energy,
+            )
+
+        raise TypeError("Unsupported Readout type '{}'.".format(type(readout)))
+
+    def print_info(self):
+        """print info"""
+        print("================================================================================")
+        print("Cybertron Engine, Ride-on!")
+        print('---with input distance unit: ' + self.unit_dis)
+        if self.fixed_atoms:
+            print('---with fixed atoms: ' + str(self.atom_types[0]))
+        if self.full_connect:
+            print('---using full connected neighbors')
+        if self.use_bonds and self.fixed_bonds:
+            print('---using fixed bond connection:')
+            for b in self.bonds[0]:
+                print('------' + str(b.asnumpy()))
+            print('---with fixed bond mask:')
+            for m in self.bond_mask[0]:
+                print('------' + str(m.asnumpy()))
+        self.model.print_info()
+
+        if self.multi_readouts:
+            print("---with multiple readouts: ")
+            for i in range(self.num_readout):
+                print("---" + str(i + 1) +
+                      (". " + self.readout[i].name + " readout"))
+        else:
+            print("---with readout type: " + self.readout.name)
+            self.readout.print_info()
+
+        if self.unit_energy is not None:
+            print("---with output units: " + str(self.unit_energy))
+            print("---with output scale: " + str(self.output_scale))
+        print("---with total output dimension: " + str(self.dim_output))
+        print("================================================================================")
+
+    def construct(self,
+                  positions=None,
+                  atom_types=None,
+                  pbcbox=None,
+                  neighbors=None,
+                  neighbor_mask=None,
+                  bonds=None,
+                  bond_mask=None,
+                  ):
+        """Compute the properties of the molecules.
+
+        Args:
+            positions (mindspore.Tensor[float], [B, A, 3]): Cartesian coordinates for each atom.
+            atom_types (mindspore.Tensor[int], [B, A]): Types (nuclear charge) of input atoms.
+                If the attribute "self.atom_types" has been set and
+                atom_types is not given here,
+                atom_types = self.atom_types
+            neighbors (mindspore.Tensor[int], [B, A, N]): Indices of the near neighbor atoms around an atom
+            neighbor_mask (mindspore.Tensor[bool], [B, A, N]): Mask for neighbors
+            bonds (mindspore.Tensor[int], [B, A, N]): Types (ID) of the bonds connecting pairs of atoms
+            bond_mask (mindspore.Tensor[bool], [B, A, N]): Mask for bonds
+
+        B:  Batch size, usually the number of input molecules or frames
+        A:  Number of input atoms, usually the number of atoms in one molecule or frame
+        N:  Number of nearest neighbor atoms around an atom
+        O:  Output dimension of the predicted properties
+
+        Returns:
+            properties (mindspore.Tensor[float], [B, A, O]): prediction for the properties of the molecules
+
+        """
+
+        atom_mask = None
+        atoms_number = None
+        if atom_types is None:
+            if self.fixed_atoms:
+                atom_types = self.atom_types
+                atom_mask = self.atom_mask
+                atoms_number = self.atoms_number
+                if self.full_connect:
+                    neighbors = self.neighbors
+                    neighbor_mask = None
+            else:
+                # raise ValueError('atom_types is missing')
+                return None
+        else:
+            atom_mask = F.expand_dims(atom_types, -1) > 0
+            atoms_number = F.cast(atom_types > 0, ms.float32)
+            atoms_number = self.molsum(atoms_number, -1)
+
+        if pbcbox is None and self.use_fixed_box:
+            pbcbox = self.pbcbox
+
+        if self.use_bonds:
+            if bonds is None:
+                if self.fixed_bonds:
+                    exones = self.ones((positions.shape[0], 1, 1), ms.int32)
+                    bonds = exones * self.bonds
+                    bond_mask = exones * self.bond_mask
+                else:
+                    # raise ValueError('bonds is missing')
+                    return None
+            if bond_mask is None:
+                bond_mask = (bonds > 0)
+
+        if neighbors is None:
+            if self.full_connect:
+                neighbors, neighbor_mask = self.fc_neighbors(atom_types)
+                if self.cut_shape:
+                    atypes = F.cast(atom_types > 0, positions.dtype)
+                    anum = self.reducesum(atypes, -1)
+                    nmax = self.reducemax(anum)
+                    nmax = F.cast(nmax, ms.int32)
+                    nmax0 = int(nmax.asnumpy())
+                    nmax1 = nmax0 - 1
+
+                    atom_types = atom_types[:,
:nmax0] + positions = positions[:, :nmax0, :] + neighbors = neighbors[:, :nmax0, :nmax1] + neighbor_mask = neighbor_mask[:, :nmax0, :nmax1] + else: + # raise ValueError('neighbors is miss') + return None + + if self.use_distances: + r_ij = self.distances( + positions, + neighbors, + neighbor_mask, + pbcbox) * self.dis_scale + else: + r_ij = 1 + neighbor_mask = bond_mask + + x, xlist = self.model(r_ij, atom_types, atom_mask, + neighbors, neighbor_mask, bonds, bond_mask) + + if self.readout is None: + return x + + if self.multi_readouts: + ytuple = () + for i in range(self.num_readout): + yi = self.readout[i]( + x, + xlist, + atom_types, + atom_mask, + atoms_number) + if self.unit_energy is not None: + yi = yi * self.output_scale[i] + ytuple = ytuple + (yi,) + y = self.concat(ytuple) + else: + y = self.readout( + x, + xlist, + atom_types, + atom_mask, + atoms_number) + if self.unit_energy is not None: + y = y * self.output_scale + + return y + + +class CybertronFF(Cybertron): + """CybertronFF""" + def __init__(self, model, dim_output=1, unit_dis='nm', unit_energy=None, readout='atomwise', max_atoms_number=0, + atom_types=None, bond_types=None, full_connect=False, pbcbox=None, cut_shape=False,): + super().__init__( + model=model, + dim_output=dim_output, + unit_dis=unit_dis, + unit_energy=unit_energy, + readout=readout, + max_atoms_number=max_atoms_number, + atom_types=atom_types, + bond_types=bond_types, + full_connect=full_connect, + pbcbox=pbcbox, + cut_shape=cut_shape, + ) + + def construct(self, + positions=None, + atom_types=None, + pbcbox=None, + neighbors=None, + neighbor_mask=None, + bonds=None, + bond_mask=None + ): + if self.full_connect and self.atom_types is not None: + atom_types = self.atom_types + + if self.use_fixed_box: + pbcbox = self.pbcbox + + return super().construct( + positions=positions, + atom_types=atom_types, + pbcbox=pbcbox, + neighbors=neighbors, + neighbor_mask=neighbor_mask, + bonds=bonds, + bond_mask=bond_mask + ) diff --git a/MindSPONGE/mindsponge/md/cybertron/decoders.py b/MindSPONGE/mindsponge/md/cybertron/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..abc70ded2022aa363e1eccba87dcd2bca43d87f1 --- /dev/null +++ b/MindSPONGE/mindsponge/md/cybertron/decoders.py @@ -0,0 +1,146 @@ +# ============================================================================ +# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University +# +# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang, +# Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao +# +# This code is a part of Cybertron-Code package. +# +# The Cybertron-Code is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""decoders""" + +from mindspore import nn + +from cybertron.blocks import MLP, Dense +from cybertron.blocks import PreActResidual +from cybertron.blocks import SeqPreActResidual +from cybertron.blocks import PreActDense + +__all__ = [ + "Decoder", + "get_decoder", + "SimpleDecoder", + "ResidualOutputBlock", +] + +_DECODER_ALIAS = dict() + + +def _decoder_register(*aliases): + """Return the alias register.""" + def alias_reg(cls): + name = cls.__name__ + name = name.lower() + if name not in _DECODER_ALIAS: + _DECODER_ALIAS[name] = cls + + for alias in aliases: + if alias not in _DECODER_ALIAS: + _DECODER_ALIAS[alias] = cls + return cls + return alias_reg + + +class Decoder(nn.Cell): + """Decoder""" + def __init__(self, n_in, n_out=1, activation=None, n_layers=1, output=None): + super().__init__() + + self.name = 'decoder' + self.n_in = n_in + self.n_out = n_out + self.n_layers = n_layers + self.output = output + self.activation = activation + + def construct(self, x): + return self.output(x) + + +@_decoder_register('halve') +class SimpleDecoder(Decoder): + """SimpleDecoder""" + def __init__(self, n_in, n_out, activation, n_layers=1,): + super().__init__( + n_in=n_in, + n_out=n_out, + activation=activation, + n_layers=n_layers, + ) + + self.name = 'halve' + + if n_layers > 0: + n_hiddens = [] + dim = n_in + for _ in range(n_layers): + dim = dim // 2 + if dim < n_out: + raise ValueError( + "The dimension of hidden layer is smaller than output dimension") + n_hiddens.append(dim) + self.output = MLP(n_in, n_out, n_hiddens, activation=activation) + else: + self.output = Dense(n_in, n_out, activation=activation) + + def __str__(self): + return 'halve' + + +@_decoder_register('residual') +class ResidualOutputBlock(Decoder): + """ResidualOutputBlock""" + def __init__(self, n_in, n_out, activation, n_layers=1,): + super().__init__( + n_in=n_in, + n_out=n_out, + activation=activation, + n_layers=n_layers, + ) + + self.name = 'residual' + + if n_layers == 1: + output_residual = PreActResidual(n_in, activation=activation) + else: + output_residual = SeqPreActResidual( + n_in, activation=activation, n_res=n_layers) + + self.output = nn.SequentialCell([ + output_residual, + PreActDense(n_in, n_out, activation=activation), + ]) + + def __str__(self): + return 'residual' + + +def get_decoder(obj, n_in, n_out, activation=None, n_layers=1,): + """get_decoder""" + if obj is None or isinstance(obj, Decoder): + return obj + if isinstance(obj, str): + if obj.lower() not in _DECODER_ALIAS.keys(): + raise ValueError( + "The class corresponding to '{}' was not found.".format(obj)) + return _DECODER_ALIAS[obj.lower()]( + n_in=n_in, + n_out=n_out, + activation=activation, + n_layers=n_layers, + ) + raise TypeError("Unsupported init type '{}'.".format(type(obj))) diff --git a/MindSPONGE/mindsponge/md/cybertron/interactions.py b/MindSPONGE/mindsponge/md/cybertron/interactions.py new file mode 100644 index 0000000000000000000000000000000000000000..f43da5994e8b2fae95263da12550f98a1be1f051 --- /dev/null +++ b/MindSPONGE/mindsponge/md/cybertron/interactions.py @@ -0,0 +1,463 @@ +# ============================================================================ +# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University +# +# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang, +# Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao +# +# This code is a part of Cybertron-Code package. 
+#
+# The Cybertron-Code is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""interactions"""
+
+import mindspore as ms
+from mindspore import nn
+from mindspore import Tensor
+from mindspore import Parameter
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common.initializer import initializer
+from mindspore.common.initializer import Normal
+
+from cybertron.blocks import Dense, MLP
+from cybertron.blocks import PreActDense
+from cybertron.blocks import SeqPreActResidual
+from cybertron.neighbors import GatherNeighbors
+from cybertron.base import Aggregate, CFconv
+from cybertron.base import PositionalEmbedding
+from cybertron.base import MultiheadAttention
+from cybertron.base import Pondering, ACTWeight
+from cybertron.base import FeedForward
+from cybertron.activations import get_activation
+
+__all__ = [
+    "Interaction",
+    "SchNetInteraction",
+    "PhysNetModule",
+    "NeuralInteractionUnit",
+]
+
+
+class Interaction(nn.Cell):
+    """Interaction"""
+    def __init__(self,
+                 gather_dim,
+                 fixed_neigh,
+                 activation=None,
+                 use_distances=True,
+                 use_bonds=False
+                 ):
+        super().__init__()
+
+        self.name = 'Interaction'
+        self.fixed_neigh = fixed_neigh
+        self.use_bonds = use_bonds
+        self.use_distances = use_distances
+        self.activation = activation
+        self.gather_neighbors = GatherNeighbors(gather_dim, fixed_neigh)
+
+    def set_fixed_neighbors(self, flag=True):
+        self.fixed_neigh = flag
+        self.gather_neighbors.fixed_neigh = flag
+
+    def _output_block(self, x):
+        return x
+
+
+class SchNetInteraction(Interaction):
+    r"""Continuous-filter convolution block used in the SchNet model.
+
+    Args:
+        dim_feature (int): number of input atomic vector dimensions.
+        num_rbf (int): number of radial basis functions.
+        dim_filter (int): dimensions of the filter network.
+        activation (callable, optional): if None, no activation function is used.
+        fixed_neigh (bool, optional): whether the neighbor list is fixed.
+        normalize_filter (bool, optional): If True, normalize the filter to the
+            number of neighbors when aggregating.
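+
+    Example (an illustrative sketch; the sizes are assumed and the
+    auxiliary tensors follow the shapes documented in construct):
+
+        >>> block = SchNetInteraction(dim_feature=64, num_rbf=32, dim_filter=64)
+        >>> # x: [B, A, F], f_ij: [B, A, N, num_rbf], c_ij: [B, A, N],
+        >>> # neighbors: [B, A, N] -> updated representation [B, A, F]
+        >>> x_new = block(x, f_ij, c_ij, neighbors, mask)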
+
+    """
+
+    def __init__(self, dim_feature, num_rbf, dim_filter, activation='swish', fixed_neigh=False,
+                 normalize_filter=False,):
+        super().__init__(
+            gather_dim=dim_filter,
+            fixed_neigh=fixed_neigh,
+            activation=activation,
+            use_bonds=False,
+        )
+
+        self.dim_filter = dim_filter
+        self.activation = activation
+
+        self.name = 'SchNet Interaction Layer'
+        self.atomwise_bc = Dense(dim_feature, dim_filter)
+        self.atomwise_ac = MLP(
+            dim_filter, dim_feature, [
+                dim_feature,], activation=activation, use_last_activation=False)
+
+        self.cfconv = CFconv(num_rbf, dim_filter, activation)
+        self.agg = Aggregate(axis=-2, mean=normalize_filter)
+
+    def print_info(self):
+        print('---------with activation function: ' + str(self.activation))
+        if self.use_distances and self.use_bonds:
+            print('---------with edge type: distances and bonds')
+        else:
+            print('---------with edge type: ' +
+                  ('distance' if self.use_distances else 'bonds'))
+
+        print('---------with dimension for filter network: ' + str(self.dim_filter))
+
+    def construct(self, x, f_ij, c_ij, neighbors, mask=None):
+        """Compute convolution block.
+
+        Args:
+            x (ms.Tensor[float]): input representation/embedding of atomic environments
+                with (N_b, N_a, n_in) shape.
+            f_ij (ms.Tensor[float]): expanded interatomic distances (RBF) of
+                (N_b, N_a, N_nbh, num_rbf) shape.
+            c_ij (ms.Tensor[float]): cutoff values of (N_b, N_a, N_nbh) shape.
+            neighbors (ms.Tensor[int]): indices of neighbors of (N_b, N_a, N_nbh) shape.
+            mask (ms.Tensor[bool]): mask to filter out non-existing neighbors
+                introduced via padding.
+
+        Returns:
+            ms.Tensor: block output with (N_b, N_a, n_out) shape.
+
+        """
+
+        ax = self.atomwise_bc(x)
+        xij = self.gather_neighbors(ax, neighbors)
+
+        # CFconv: pass expanded interatomic distances through the filter block
+        y = self.cfconv(xij, f_ij, c_ij)
+        # element-wise multiplication, aggregating and Dense layer
+        y = self.agg(y, mask)
+
+        v = self.atomwise_ac(y)
+
+        x_new = x + v
+
+        return x_new
+
+
+class PhysNetModule(Interaction):
+    r"""Interaction module used in the PhysNet model.
+
+    Args:
+        num_rbf (int): number of radial basis functions.
+        dim_feature (int): number of input atomic vector dimensions.
+        activation (callable, optional): if None, no activation function is used.
+        fixed_neigh (bool, optional): whether the neighbor list is fixed.
+        n_inter_residual (int): number of residual blocks inside the interaction block.
+        n_outer_residual (int): number of residual blocks after the interaction block.
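+
+    Example (an illustrative sketch with assumed tensor shapes):
+
+        >>> module = PhysNetModule(num_rbf=64, dim_feature=128)
+        >>> # x: [B, A, F], f_ij: [B, A, N, num_rbf], c_ij: [B, A, N],
+        >>> # neighbors: [B, A, N] -> updated representation [B, A, F]
+        >>> x_new = module(x, f_ij, c_ij, neighbors, mask)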
+
+    """
+
+    def __init__(self, num_rbf, dim_feature, activation='swish', fixed_neigh=False, n_inter_residual=3,
+                 n_outer_residual=2,):
+        super().__init__(
+            gather_dim=dim_feature,
+            fixed_neigh=fixed_neigh,
+            activation=activation,
+            use_bonds=False,
+        )
+
+        self.name = 'PhysNet Module Layer'
+
+        self.xi_dense = Dense(dim_feature, dim_feature, activation=activation)
+        self.xij_dense = Dense(dim_feature, dim_feature, activation=activation)
+        self.fij_dense = Dense(
+            num_rbf,
+            dim_feature,
+            has_bias=False,
+            activation=None)
+
+        self.gating_vector = Parameter(
+            initializer(Normal(1.0), [dim_feature,]), name="gating_vector")
+
+        self.inter_residual = SeqPreActResidual(
+            dim_feature, activation=activation, n_res=n_inter_residual)
+        self.inter_dense = PreActDense(
+            dim_feature, dim_feature, activation=activation)
+        self.outer_residual = SeqPreActResidual(
+            dim_feature, activation=activation, n_res=n_outer_residual)
+
+        self.activation = get_activation(activation)
+
+        self.reducesum = P.ReduceSum()
+
+    def print_info(self):
+        print('---------with activation function: ' + str(self.activation))
+        if self.use_distances and self.use_bonds:
+            print('---------with edge type: distances and bonds')
+        else:
+            print('---------with edge type: ' +
+                  ('distance' if self.use_distances else 'bonds'))
+
+    def _interaction_block(self, x, f_ij, c_ij, neighbors, mask):
+        """_interaction_block"""
+
+        xi = self.activation(x)
+        xij = self.gather_neighbors(xi, neighbors)
+
+        ux = self.gating_vector * x
+
+        dxi = self.xi_dense(xi)
+        dxij = self.xij_dense(xij)
+        g_gij = self.fij_dense(f_ij * F.expand_dims(c_ij, -1))
+
+        side = g_gij * dxij
+        if mask is not None:
+            side = side * F.expand_dims(mask, -1)
+        v = dxi + self.reducesum(side, -2)
+
+        v1 = self.inter_residual(v)
+        v1 = self.inter_dense(v1)
+        return ux + v1
+
+    def construct(self, x, f_ij, c_ij, neighbors, mask=None):
+        """Compute the interaction module.
+
+        Args:
+            x (ms.Tensor[float]): input representation/embedding of atomic environments
+                with (N_b, N_a, n_in) shape.
+            f_ij (ms.Tensor[float]): expanded interatomic distances (RBF) of
+                (N_b, N_a, N_nbh, num_rbf) shape.
+            c_ij (ms.Tensor[float]): cutoff values of (N_b, N_a, N_nbh) shape.
+            neighbors (ms.Tensor[int]): indices of neighbors of (N_b, N_a, N_nbh) shape.
+            mask (ms.Tensor[bool]): mask to filter out non-existing neighbors
+                introduced via padding.
+
+        Returns:
+            ms.Tensor: block output with (N_b, N_a, n_out) shape.
+
+        """
+
+        x1 = self._interaction_block(x, f_ij, c_ij, neighbors, mask)
+        xnew = self.outer_residual(x1)
+
+        return xnew
+
+
+class NeuralInteractionUnit(Interaction):
+    r"""Transformer-style interaction unit with multi-head attention and
+    adaptive computation time (ACT), used in the MolCT model.
+
+    Args:
+        dim_feature (int): dimensions of feature space.
+        num_rbf (int): number of radial basis functions.
+        n_heads (int): number of heads for multi-head attention.
+        activation (callable, optional): if None, no activation function is used.
+        max_cycles (int): maximum number of encoder cycles for ACT.
+        time_embedding (sequence): embedding vectors indexed by cycle number.
+        use_pondering (bool, optional): whether to use a pondering network
+            to decide when each atom halts.
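+
+    Example (an illustrative sketch; the time_embedding sequence, the zero
+    bond placeholders and the tensor shapes are assumptions, see MolCT in
+    models.py for real usage):
+
+        >>> unit = NeuralInteractionUnit(dim_feature=128, num_rbf=128, n_heads=8,
+        ...                              max_cycles=1, time_embedding=t_embed)
+        >>> # x: [B, A, F], e: [B, A, F], f_ij: [B, A, N, num_rbf],
+        >>> # c_ij: [B, A, N], neighbors: [B, A, N]
+        >>> x_new = unit(x, e, f_ii, f_ij, 0, 0, c_ij, neighbors, mask)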
+    """
+
+    def __init__(self, dim_feature, num_rbf, n_heads=8, activation='swish', max_cycles=10, time_embedding=0,
+                 use_pondering=True, fixed_cycles=False, use_distances=True, use_dis_filter=True, use_bonds=False,
+                 use_bond_filter=False, act_threshold=0.9, fixed_neigh=False, use_feed_forward=False,):
+        super().__init__(
+            gather_dim=dim_feature,
+            fixed_neigh=fixed_neigh,
+            activation=activation,
+            use_distances=use_distances,
+            use_bonds=use_bonds
+        )
+        if dim_feature % n_heads != 0:
+            raise ValueError('The term "dim_feature" must be divisible ' +
+                             'by the term "n_heads" in NeuralInteractionUnit!')
+
+        self.name = 'Neural Interaction Unit'
+
+        self.n_heads = n_heads
+        self.max_cycles = max_cycles
+        self.dim_feature = dim_feature
+        self.num_rbf = num_rbf
+        self.time_embedding = time_embedding
+
+        if fixed_cycles:
+            self.flexable_cycels = False
+        else:
+            self.flexable_cycels = True
+
+        self.use_dis_filter = use_dis_filter
+        if self.use_dis_filter:
+            self.dis_filter = Dense(
+                num_rbf,
+                dim_feature,
+                has_bias=True,
+                activation=None)
+        else:
+            self.dis_filter = None
+
+        self.bond_filter = None
+        self.use_bond_filter = use_bond_filter
+        if self.use_bond_filter:
+            self.bond_filter = Dense(
+                dim_feature,
+                dim_feature,
+                has_bias=False,
+                activation=None)
+
+        self.positional_embedding = PositionalEmbedding(
+            dim_feature, self.use_distances, self.use_bonds)
+        self.multi_head_attention = MultiheadAttention(
+            dim_feature, n_heads, dim_tensor=4)
+
+        self.use_feed_forward = use_feed_forward
+        self.feed_forward = None
+        if self.use_feed_forward:
+            self.feed_forward = FeedForward(dim_feature, activation)
+
+        self.act_threshold = act_threshold
+        self.act_epsilon = 1.0 - act_threshold
+
+        self.use_pondering = use_pondering
+        self.pondering = None
+        self.act_weight = None
+        if self.max_cycles > 1:
+            if self.use_pondering:
+                self.pondering = Pondering(dim_feature * 3, bias_const=3)
+                self.act_weight = ACTWeight(self.act_threshold)
+            else:
+                if self.flexable_cycels:
+                    raise ValueError(
+                        'The term "fixed_cycles" must be True when the ' +
+                        'pondering network is None in NeuralInteractionUnit!')
+                self.fixed_weight = Tensor(1.0 / max_cycles, ms.float32)
+
+        self.max = P.Maximum()
+        self.min = P.Minimum()
+        self.concat = P.Concat(-1)
+        self.reducesum = P.ReduceSum()
+        self.squeeze = P.Squeeze(-2)
+        self.ones_like = P.OnesLike()
+        self.zeros_like = P.ZerosLike()
+        self.zeros = P.Zeros()
+
+    def print_info(self):
+        """print info"""
+        print('---------with activation function: ' + str(self.activation))
+        if self.use_distances and self.use_bonds:
+            print('---------with edge type: distances and bonds')
+        else:
+            print('---------with edge type: ' +
+                  ('distance' if self.use_distances else 'bonds'))
+
+        if self.use_distances:
+            print('---------with filter for distances: ' +
+                  ('yes' if self.use_dis_filter else 'no'))
+
+        if self.use_bonds:
+            print('---------with filter for bonds: ' +
+                  ('yes' if self.use_bond_filter else 'no'))
+
+        print('---------with multi-heads: ' + str(self.n_heads))
+        print('---------with feed forward: ' +
+              ('yes' if self.use_feed_forward else 'no'))
+        if self.max_cycles > 1:
+            print('---------using adaptive computation time with threshold: ' +
+                  str(self.act_threshold))
+
+    def _transformer_encoder(self, x, neighbors, g_ii=1, g_ij=1, b_ii=0, b_ij=0, c_ij=None, t=0, mask=None):
+        """_transformer_encoder"""
+
+        xij = self.gather_neighbors(x, neighbors)
+        q, k, v = self.positional_embedding(
+            x, xij, g_ii, g_ij, b_ii, b_ij, c_ij, t)
+        v = self.multi_head_attention(q, k, v, mask=mask, cutoff=c_ij)
+        v = self.squeeze(v)
+
+        if self.use_feed_forward:
+            return self.feed_forward(x + v)
+        return x + v
+
+    def construct(self, x, e, f_ii, f_ij, b_ii, b_ij, c_ij, neighbors, mask=None):
+        """Compute the interaction unit.
+
+        Args:
+            x (ms.Tensor[float]): input representation/embedding of atomic environments
+                with (N_b, N_a, n_in) shape.
+            e (ms.Tensor[float]): reference atomic embedding, concatenated to the
+                state for the pondering network.
+            f_ii, f_ij (ms.Tensor[float]): radial basis expansions for the self
+                and neighbor distances.
+            b_ii, b_ij (ms.Tensor[float]): representations for the self and neighbor bonds.
+            c_ij (ms.Tensor[float]): cutoff values of (N_b, N_a, N_nbh) shape.
+            neighbors (ms.Tensor[int]): indices of neighbors of (N_b, N_a, N_nbh) shape.
+            mask (ms.Tensor[bool]): mask to filter out non-existing neighbors
+                introduced via padding.
+
+        Returns:
+            ms.Tensor: block output with (N_b, N_a, n_out) shape.
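+
+        Note (a summary of the adaptive-computation-time loop below): when
+        max_cycles > 1, the transformer encoder is applied repeatedly and
+        its outputs are blended with the halting weights w returned by
+        ACTWeight, so atoms that halt early keep their earlier
+        representations while the others keep updating.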
+ + """ + + if self.use_distances and self.use_dis_filter: + g_ii = self.dis_filter(f_ii) + g_ij = self.dis_filter(f_ij) + else: + g_ii = f_ii + g_ij = f_ij + + if self.use_bond_filter: + b_ii = self.bond_filter(b_ii) + b_ij = self.bond_filter(b_ij) + + if self.max_cycles == 1: + t = self.time_embedding[0] + x0 = self._transformer_encoder( + x, neighbors, g_ii, g_ij, b_ii, b_ij, c_ij, t, mask) + + else: + xx = x + x0 = self.zeros_like(x) + + halting_prob = self.zeros((x.shape[0], x.shape[1]), ms.float32) + n_updates = self.zeros((x.shape[0], x.shape[1]), ms.float32) + + broad_zeros = self.zeros_like(e) + + if self.flexable_cycels: + cycle = self.zeros((), ms.int32) + while((halting_prob < self.act_threshold).any() and (cycle < self.max_cycles)): + t = self.time_embedding[cycle] + vt = broad_zeros + t + + xp = self.concat((xx, e, vt)) + p = self.pondering(xp) + w, dp, dn = self.act_weight(p, halting_prob) + halting_prob = halting_prob + dp + n_updates = n_updates + dn + + xx = self._transformer_encoder( + xx, neighbors, g_ii, g_ij, b_ii, b_ij, c_ij, t, mask) + + cycle = cycle + 1 + + x0 = xx * w + x0 * (1.0 - w) + else: + for cycle in range(self.max_cycles): + t = self.time_embedding[cycle] + vt = broad_zeros + t + + xp = self.concat((xx, e, vt)) + p = self.pondering(xp) + w, dp, dn = self.act_weight(p, halting_prob) + halting_prob = halting_prob + dp + n_updates = n_updates + dn + + xx = self._transformer_encoder( + xx, neighbors, g_ii, g_ij, b_ii, b_ij, c_ij, t, mask) + + cycle = cycle + 1 + + x0 = xx * w + x0 * (1.0 - w) + + return x0 diff --git a/MindSPONGE/mindsponge/md/cybertron/mdnn.py b/MindSPONGE/mindsponge/md/cybertron/mdnn.py new file mode 100644 index 0000000000000000000000000000000000000000..2d4d6df112b156e24c97ed89cdabab287607c4b6 --- /dev/null +++ b/MindSPONGE/mindsponge/md/cybertron/mdnn.py @@ -0,0 +1,76 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""mdnn class""" + +import numpy as np +from mindspore import nn, Tensor +from mindspore.ops import operations as P +from mindspore.common.parameter import Parameter +import mindspore.common.dtype as mstype + + +__all__ = [ + 'Mdnn', + 'TransCrdToCV' +] + + +class Mdnn(nn.Cell): + """Mdnn""" + + def __init__(self, dim=258, dr=0.5): + super(Mdnn, self).__init__() + self.dim = dim + self.dr = dr # dropout_ratio + self.fc1 = nn.Dense(dim, 512) + self.fc2 = nn.Dense(512, 512) + self.fc3 = nn.Dense(512, 512) + self.fc4 = nn.Dense(512, 129) + self.tanh = nn.Tanh() + + def construct(self, x): + """construct""" + x = self.tanh(self.fc1(x)) + x = self.tanh(self.fc2(x)) + x = self.tanh(self.fc3(x)) + x = self.fc4(x) + return x + + +class TransCrdToCV(nn.Cell): + """TransCrdToCV""" + + def __init__(self, simulation): + super(TransCrdToCV, self).__init__() + self.atom_numbers = simulation.atom_numbers + self.transfercrd = P.TransferCrd(0, 129, 129, self.atom_numbers) + self.box = Tensor(simulation.box_length) + self.radial = Parameter(Tensor(np.zeros([129,]), mstype.float32)) + self.angular = Parameter(Tensor(np.zeros([129,]), mstype.float32)) + self.output = Parameter(Tensor(np.zeros([1, 258]), mstype.float32)) + self.charge = simulation.charge + + def updatecharge(self, t_charge): + """update charge in simulation""" + self.charge[:129] = t_charge[0] * 18.2223 + return self.charge + + def construct(self, crd, last_crd): + """construct""" + self.radial, self.angular, _, _ = self.transfercrd( + crd, last_crd, self.box) + self.output = P.Concat()((self.radial, self.angular)) + self.output = P.ExpandDims()(self.output, 0) + return self.output diff --git a/MindSPONGE/mindsponge/md/cybertron/meta_dynamics.py b/MindSPONGE/mindsponge/md/cybertron/meta_dynamics.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d26bd8e8ae9e1f415750d9ee11df1b377405f9 --- /dev/null +++ b/MindSPONGE/mindsponge/md/cybertron/meta_dynamics.py @@ -0,0 +1,106 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""meta_dynamics"""
+
+import numpy as np
+from mindspore import nn
+import mindspore as ms
+from mindspore import ops
+from mindspore import Tensor
+from mindspore import Parameter
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore import dtype as mstype
+
+from .units import units
+
+__all__ = [
+    'Bias'
+]
+
+
+class Bias(nn.Cell):
+    """Bias: metadynamics bias potential with upper and lower walls"""
+    def __init__(self, hills, smin=0.15, smax=0.25, ds=0.0005,
+                 omega=0.2, sigma=0.0002, dt=0.001,
+                 t=300, alpha=0.5, gamma=6,
+                 wall_potential=9e08,
+                 kappa=4,
+                 upper_bound=190,
+                 lower_bound=10,
+                 factor=1):
+        super(Bias, self).__init__()
+        self.smin = smin
+        self.smax = smax
+        self.ds = ds
+        self.dt = dt
+        self.alpha = alpha
+        self.grid_num = 200
+        self.wall_potential = Tensor([wall_potential], mstype.float32)
+        self.wall = np.zeros(self.grid_num, dtype=np.float32)
+        # hard walls on the first and last 20 grid points
+        for i in range(20):
+            self.wall[i] = wall_potential
+            self.wall[-i - 1] = wall_potential
+        self.wall_num = Tensor([5], mstype.float32)
+        self.temperature = t
+        self.kb = units.boltzmann()
+        self.kbt = self.kb * self.temperature
+        self.beta = 1.0 / self.kbt
+        self.gamma = gamma
+        self.wt_factor = -1.0 / (self.gamma - 1.0) * self.beta
+        self.kappa = kappa
+        # grid indices bounding the biased region (a 200-point grid by default)
+        self.upper_bound = upper_bound
+        self.lower_bound = lower_bound
+        self.wall_factor = factor
+        self.wall = Tensor(self.wall, mstype.float32)
+        self.hill_matrix = Tensor([[abs(j - i) for j in range(self.grid_num)]
+                                   for i in range(self.grid_num)], mstype.int32)
+        self.sum = ops.ReduceSum()
+        if hills is None:
+            self.hills = Parameter(Tensor([0], mstype.float32))
+        else:
+            self.hills = Parameter(hills)
+        self.sqrt2 = F.sqrt(Tensor(2.0, ms.float32))
+        self.omega = omega
+        self.sigma = Tensor(sigma, mstype.float32)
+        self.exp = ops.Exp()
+        self.add = ops.Add()
+        self.norm = nn.Norm(axis=0)
+        self.square = ops.Square()
+        self.squeeze = ops.Squeeze()
+        self.zeros = ops.Zeros()
+        self.ones = ops.Ones()
+        self.cast = ops.Cast()
+        self.keep_sum = P.ReduceSum(keep_dims=True)
+        # grid point i corresponds to cv = smin + i * ds; smin may be given
+        # as a float or as a one-element Tensor
+        smin0 = float(smin[0].asnumpy()) if isinstance(smin, Tensor) else float(smin)
+        self.cv_list = Tensor([i * ds + smin0 for i in range(self.grid_num)],
+                              dtype=mstype.float32)
+
+    def get_cv(self, r):
+        """get_cv: collective variable, the distance between atoms 11 and 14"""
+        return self.norm(self.add(r[11], -r[14]))
+
+    def construct(self, r):
+        """construct"""
+        cv = self.get_cv(r)
+        cv_index = self.cast((cv - self.smin) / self.ds, mstype.int32)
+        gaussian = self.omega * \
+            self.exp(-self.square(cv - self.cv_list) / 2 / self.square(self.sigma))
+        # history-dependent bias, active only inside [lower_bound, upper_bound)
+        bias = self.sum(self.dt * self.hills * gaussian) * \
+            (cv_index < self.upper_bound) * (cv_index >= self.lower_bound)
+        # upper wall
+        bias += self.kappa * ((cv - (self.smin + self.upper_bound * self.ds)) /
+                              self.wall_factor) ** 2 * (cv_index >= self.upper_bound)
+        # lower wall
+        bias += self.kappa * (((self.smin + self.lower_bound * self.ds) - cv) /
+                              self.wall_factor) ** 2 * (cv_index < self.lower_bound)
+        return bias
diff --git a/MindSPONGE/mindsponge/md/cybertron/models.py b/MindSPONGE/mindsponge/md/cybertron/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fee908ef79b2ec25d3f60ea154e0280e5594ea3
--- /dev/null
+++ b/MindSPONGE/mindsponge/md/cybertron/models.py
@@ -0,0 +1,894 @@
+# ============================================================================
+# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University
+#
+# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang,
+#         Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao
+#
+# This code is a part of Cybertron-Code package.
+#
+# The Cybertron-Code is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""models"""
+
+import numpy as np
+import mindspore as ms
+from mindspore import nn
+from mindspore import Tensor
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+from mindspore.common.initializer import Normal
+
+from cybertron.units import units
+from cybertron.blocks import Dense, Residual
+from cybertron.interactions import SchNetInteraction
+from cybertron.interactions import PhysNetModule
+from cybertron.interactions import NeuralInteractionUnit
+from cybertron.base import ResFilter, GraphNorm
+from cybertron.cutoff import get_cutoff
+from cybertron.rbf import GaussianSmearing, LogGaussianDistribution
+from cybertron.activations import ShiftedSoftplus, Swish
+
+__all__ = [
+    "DeepGraphMolecularModel",
+    "SchNet",
+    "PhysNet",
+    "MolCT",
+]
+
+
+class DeepGraphMolecularModel(nn.Cell):
+    r"""Base class for graph neural network (GNN) based deep molecular models
+
+    Args:
+        num_elements (int): maximum number of atomic types
+        num_rbf (int): number of radial basis functions (RBF)
+        dim_feature (int): dimension of the vectors for atomic embedding
+        atom_types (ms.Tensor[int], optional): atomic index
+        rbf_function (nn.Cell, optional): the algorithm to calculate RBF
+        cutoff_network (nn.Cell, optional): the algorithm to calculate cutoff.
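+
+        Example (illustrative only; the subclass and argument values below are
+        assumptions, not verified defaults):
+
+        >>> net = MolCT(dim_feature=64, n_interactions=3, unit_length='A')
+        >>> # x, xlist = net(r_ij, atom_types, atom_mask, neighbors, neighbor_mask)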
+
+    """
+
+    def __init__(
+            self,
+            num_elements,
+            min_rbf_dis,
+            max_rbf_dis,
+            num_rbf,
+            dim_feature,
+            n_interactions,
+            interactions=None,
+            unit_length='nm',
+            activation=None,
+            rbf_sigma=None,
+            trainable_rbf=False,
+            rbf_function=None,
+            cutoff=None,
+            cutoff_network=None,
+            rescale_rbf=False,
+            use_distances=True,
+            use_public_filter=False,
+            use_graph_norm=False,
+            dis_filter=None
+    ):
+        super().__init__()
+        self.num_elements = num_elements
+        self.dim_feature = dim_feature
+        self.num_rbf = num_rbf
+        self.rescale_rbf = rescale_rbf
+
+        self.activation = activation
+
+        self.interaction_types = interactions
+        if isinstance(interactions, list):
+            self.n_interactions = len(interactions)
+        else:
+            self.n_interactions = n_interactions
+
+        self.unit_length = unit_length
+        units.set_length_unit(self.unit_length)
+
+        self.use_distances = use_distances
+        self.use_bonds = False
+
+        self.network_name = 'DeepGraphMolecularModel'
+
+        self.read_all_interactions = False
+
+        # make a lookup table to store embeddings for each element (up to atomic
+        # number max_z) each of which is a vector of size dim_feature
+        self.atom_embedding = nn.Embedding(
+            num_elements,
+            dim_feature,
+            use_one_hot=True,
+            embedding_table=Normal(1.0))
+
+        self.bond_embedding = [None,]
+        self.bond_filter = [None,]
+
+        self.use_public_filter = use_public_filter
+
+        self.dis_filter = dis_filter
+
+        self.fixed_atoms = False
+
+        # layer for expanding interatomic distances in a basis
+        if rbf_function is not None:
+            self.rbf_function = rbf_function(
+                d_min=min_rbf_dis,
+                d_max=max_rbf_dis,
+                num_rbf=num_rbf,
+                sigma=rbf_sigma,
+                trainable=trainable_rbf)
+        else:
+            self.rbf_function = None
+
+        self.cutoff_network = None
+        self.cutoff = None
+        if cutoff_network is not None:
+            if cutoff is None:
+                self.cutoff = max_rbf_dis
+            else:
+                self.cutoff = cutoff
+            self.cutoff_network = get_cutoff(
+                cutoff_network,
+                r_max=self.cutoff,
+                return_mask=True,
+                reverse=False)
+
+        self.interactions = [None,]
+
+        self.interaction_typenames = []
+
+        self.use_graph_norm = use_graph_norm
+        self.use_pub_norm = False
+
+        if self.use_graph_norm:
+            if self.use_pub_norm:
+                # share one GraphNorm instance across all interaction layers
+                self.graph_norm = nn.CellList(
+                    [GraphNorm(dim_feature)] * self.n_interactions
+                )
+            else:
+                self.graph_norm = nn.CellList(
+                    [GraphNorm(dim_feature) for _ in range(self.n_interactions)]
+                )
+        else:
+            self.graph_norm = None
+
+        self.decoder = 'halve'
+        self.merge_method = None
+        self.far_type = None
+
+        self.zeros = P.Zeros()
+        self.ones = P.Ones()
+
+    def print_info(self):
+        """print info"""
+        print('---with GNN-based deep molecular model: ', self.network_name)
+        print('------with atom embedding size: ' + str(self.num_elements))
+        print('------with cutoff distance: ' +
+              str(self.cutoff) + ' ' + self.unit_length)
+        print('------with number of RBF functions: ' + str(self.num_rbf))
+        print('------with bond connection: ' +
+              ('Yes' if self.use_bonds else 'No'))
+        print('------with feature dimension: ' + str(self.dim_feature))
+        print('------with interaction layers:')
+        for i, inter in enumerate(self.interactions):
+            print('------' + str(i + 1) + '. ' + inter.name +
+                  '(' + self.interaction_typenames[i] + ')')
+            inter.print_info()
+        print('------with total layers: ' + str(len(self.interactions)))
+        print('------output all interaction layers: ' +
+              ('Yes' if self.read_all_interactions else 'No'))
+
+    def set_fixed_atoms(self, fixed_atoms=True):
+        self.fixed_atoms = fixed_atoms
+
+    def set_fixed_neighbors(self, flag=True):
+        for interaction in self.interactions:
+            interaction.set_fixed_neighbors(flag)
+
+    def _calc_cutoffs(
+            self,
+            r_ij=1,
+            neighbor_mask=None,
+            bonds=None,
+            bond_mask=None,
+            atom_mask=None):
+        """_calc_cutoffs"""
+
+        self.bonds_t = bonds
+        self.bond_mask_t = bond_mask
+        self.atom_mask_t = atom_mask
+
+        if self.cutoff_network is None:
+            return F.ones_like(r_ij), neighbor_mask
+        return self.cutoff_network(r_ij, neighbor_mask)
+
+    def _get_rbf(self, dis):
+        """_get_rbf"""
+        # expand interatomic distances (for example, Gaussian smearing)
+        if self.rbf_function is None:
+            rbf = F.expand_dims(dis, -1)
+        else:
+            rbf = self.rbf_function(dis)
+
+        if self.rescale_rbf:
+            rbf = rbf * 2.0 - 1.0
+
+        if self.dis_filter is not None:
+            return self.dis_filter(rbf)
+        return rbf
+
+    def _get_self_rbf(self):
+        return 0
+
+    def construct(
+            self,
+            r_ij=1,
+            atom_types=None,
+            atom_mask=None,
+            neighbors=None,
+            neighbor_mask=None,
+            bonds=None,
+            bond_mask=None):
+        """Compute interaction output.
+
+        Args:
+            r_ij (ms.Tensor[float], [B, A, N]): interatomic distances of (N_b, N_a, N_nbh) shape.
+            neighbors (ms.Tensor[int]): indices of neighbors of (N_b, N_a, N_nbh) shape.
+            neighbor_mask (ms.Tensor[bool], optional): mask to filter out non-existing neighbors
+                introduced via padding.
+            atom_types (ms.Tensor[int], optional): atomic index
+
+        Returns:
+            ms.Tensor: block output with (N_b, N_a, N_basis) shape.
+
+        """
+
+        if self.fixed_atoms:
+            exones = self.ones((r_ij.shape[0], 1, 1), r_ij.dtype)
+            e = exones * self.atom_embedding(atom_types)
+            if atom_mask is not None:
+                atom_mask = (exones * atom_mask) > 0
+        else:
+            e = self.atom_embedding(atom_types)
+
+        if self.use_distances:
+            f_ij = self._get_rbf(r_ij)
+            f_ii = self._get_self_rbf()
+        else:
+            f_ii = 1
+            f_ij = 1
+
+        b_ii = 0
+        b_ij = 0
+
+        # apply cutoff
+        c_ij, mask = self._calc_cutoffs(
+            r_ij, neighbor_mask, bonds, bond_mask, atom_mask)
+
+        # continuous-filter convolution interaction block followed by Dense
+        # layer
+        x = e
+        n_interactions = len(self.interactions)
+        xlist = []
+        for i in range(n_interactions):
+            x = self.interactions[i](
+                x, e, f_ii, f_ij, b_ii, b_ij, c_ij, neighbors, mask)
+
+            if self.use_graph_norm:
+                x = self.graph_norm[i](x)
+
+            if self.read_all_interactions:
+                xlist.append(x)
+
+        if self.read_all_interactions:
+            return x, xlist
+        return x, None
+
+
+class SchNet(DeepGraphMolecularModel):
+    r"""SchNet Model.
+
+    References:
+        Schütt, K. T.; Sauceda, H. E.; Kindermans, P.-J.; Tkatchenko, A.; Müller K.-R.,
+        SchNet - a deep learning architecture for molecules and materials.
+        The Journal of Chemical Physics 148 (24), 241722. 2018.
+
+    Args:
+        num_elements (int): maximum number of atomic types
+        num_rbf (int): number of radial basis functions (RBF)
+        dim_feature (int): dimension of the vectors for atomic embedding
+        dim_filter (int): dimension of the vectors for filters used in continuous-filter convolution.
+        n_interactions (int, optional): number of interaction blocks.
+        max_rbf_dis (float): the maximum distance to calculate RBF.
+        atom_types (ms.Tensor[int], optional): atomic index
+        rbf_function (nn.Cell, optional): the algorithm to calculate RBF
+        cutoff_network (nn.Cell, optional): the algorithm to calculate cutoff.
+        normalize_filter (bool, optional): if True, divide aggregated filter by number
+            of neighbors over which convolution is applied.
+        coupled_interactions (bool, optional): if True, share the weights across
+            interaction blocks and filter-generating networks.
+        trainable_rbf (bool, optional): if True, widths and offset of Gaussian
+            functions are adjusted during training process.
+
+    """
+
+    def __init__(
+            self,
+            num_elements=100,
+            dim_feature=64,
+            min_rbf_dis=0.02,
+            max_rbf_dis=0.5,
+            num_rbf=32,
+            dim_filter=64,
+            n_interactions=3,
+            activation=ShiftedSoftplus(),
+            unit_length='nm',
+            rbf_sigma=None,
+            rbf_function=GaussianSmearing,
+            cutoff=None,
+            cutoff_network='cosine',
+            normalize_filter=False,
+            coupled_interactions=False,
+            trainable_rbf=False,
+            use_graph_norm=False,
+    ):
+        super().__init__(
+            num_elements=num_elements,
+            dim_feature=dim_feature,
+            min_rbf_dis=min_rbf_dis,
+            max_rbf_dis=max_rbf_dis,
+            num_rbf=num_rbf,
+            n_interactions=n_interactions,
+            activation=activation,
+            unit_length=unit_length,
+            rbf_sigma=rbf_sigma,
+            rbf_function=rbf_function,
+            cutoff=cutoff,
+            cutoff_network=cutoff_network,
+            rescale_rbf=False,
+            use_public_filter=False,
+            use_graph_norm=use_graph_norm,
+            trainable_rbf=trainable_rbf,
+        )
+        self.network_name = 'SchNet'
+
+        # block for computing interaction
+        if coupled_interactions:
+            self.interaction_typenames = ['D0',] * self.n_interactions
+            # use the same SchNetInteraction instance (hence the same weights)
+            self.interactions = nn.CellList(
+                [
+                    SchNetInteraction(
+                        dim_feature=dim_feature,
+                        num_rbf=num_rbf,
+                        dim_filter=dim_filter,
+                        activation=self.activation,
+                        normalize_filter=normalize_filter,
+                    )
+                ]
+                * self.n_interactions
+            )
+        else:
+            self.interaction_typenames = [
+                'D' + str(i) for i in range(self.n_interactions)]
+            # use one SchNetInteraction instance for each interaction
+            self.interactions = nn.CellList(
+                [
+                    SchNetInteraction(
+                        dim_feature=dim_feature,
+                        num_rbf=num_rbf,
+                        dim_filter=dim_filter,
+                        activation=self.activation,
+                        normalize_filter=normalize_filter,
+                    )
+                    for _ in range(self.n_interactions)
+                ]
+            )
+
+
+class PhysNet(DeepGraphMolecularModel):
+    r"""PhysNet Model
+
+    References:
+        Unke, O. T. and Meuwly, M.,
+        PhysNet: A neural network for predicting energies, forces, dipole moments, and partial charges.
+        Journal of Chemical Theory and Computation 2019, 15(6), 3678-3693.
+
+    Args:
+        num_elements (int): maximum number of atomic types
+        num_rbf (int): number of radial basis functions (RBF)
+        dim_feature (int): dimension of the vectors for atomic embedding
+        n_interactions (int, optional): number of interaction blocks.
+        n_inter_residual (int, optional): number of inner residual blocks per module.
+        n_outer_residual (int, optional): number of outer residual blocks per module.
+        max_rbf_dis (float): the maximum distance to calculate RBF.
+        atom_types (ms.Tensor[int], optional): atomic index
+        rbf_function (nn.Cell, optional): the algorithm to calculate RBF
+        cutoff_network (nn.Cell, optional): the algorithm to calculate cutoff.
+        coupled_interactions (bool, optional): if True, share the weights across
+            interaction blocks and filter-generating networks.
+        trainable_rbf (bool, optional): if True, widths and offset of Gaussian
+            functions are adjusted during training process.
+
+    """
+
+    def __init__(
+            self,
+            num_elements=100,
+            min_rbf_dis=0.02,
+            max_rbf_dis=1,
+            num_rbf=64,
+            dim_feature=128,
+            n_interactions=5,
+            n_inter_residual=3,
+            n_outer_residual=2,
+            unit_length='nm',
+            activation=ShiftedSoftplus(),
+            rbf_sigma=None,
+            rbf_function=GaussianSmearing,
+            cutoff=None,
+            cutoff_network='smooth',
+            use_graph_norm=False,
+            coupled_interactions=False,
+            trainable_rbf=False,
+    ):
+        super().__init__(
+            num_elements=num_elements,
+            dim_feature=dim_feature,
+            min_rbf_dis=min_rbf_dis,
+            max_rbf_dis=max_rbf_dis,
+            num_rbf=num_rbf,
+            n_interactions=n_interactions,
+            activation=activation,
+            rbf_sigma=rbf_sigma,
+            unit_length=unit_length,
+            rbf_function=rbf_function,
+            cutoff=cutoff,
+            cutoff_network=cutoff_network,
+            rescale_rbf=False,
+            use_graph_norm=use_graph_norm,
+            use_public_filter=False,
+            trainable_rbf=trainable_rbf,
+        )
+        self.network_name = 'PhysNet'
+
+        # block for computing interaction
+        if coupled_interactions:
+            self.interaction_typenames = ['D0',] * self.n_interactions
+            # use the same PhysNetModule instance (hence the same weights)
+            self.interactions = nn.CellList(
+                [
+                    PhysNetModule(
+                        num_rbf=num_rbf,
+                        dim_feature=dim_feature,
+                        activation=self.activation,
+                        n_inter_residual=n_inter_residual,
+                        n_outer_residual=n_outer_residual,
+                    )
+                ]
+                * self.n_interactions
+            )
+        else:
+            self.interaction_typenames = [
+                'D' + str(i) for i in range(self.n_interactions)]
+            # use one PhysNetModule instance for each interaction
+            self.interactions = nn.CellList(
+                [
+                    PhysNetModule(
+                        num_rbf=num_rbf,
+                        dim_feature=dim_feature,
+                        activation=self.activation,
+                        n_inter_residual=n_inter_residual,
+                        n_outer_residual=n_outer_residual,
+                    )
+                    for _ in range(self.n_interactions)
+                ]
+            )
+
+        self.readout = None
+
+    def set_fixed_neighbors(self, flag=True):
+        for interaction in self.interactions:
+            interaction.set_fixed_neighbors(flag)
+
+
+class MolCT(DeepGraphMolecularModel):
+    r"""Molecular Configuration Transformer (MolCT) Model
+
+    References:
+        Zhang, J.; Zhou, Y.; Lei, Y.-K.; Yang, Y. I.; Gao, Y.
Q., + Molecular CT: unifying geometry and representation learning for molecules at different scales + ArXiv: 2012.11816 + + Args: + + + + """ + + def __init__( + self, + num_elements=100, + min_rbf_dis=0.05, + max_rbf_dis=1, + num_rbf=32, + dim_feature=64, + n_interactions=3, + interactions=None, + n_heads=8, + max_cycles=10, + activation=Swish(), + unit_length='nm', + self_dis=None, + rbf_sigma=None, + rbf_function=LogGaussianDistribution, + cutoff=None, + cutoff_network='smooth', + use_distances=True, + use_bonds=False, + num_bond_types=16, + public_dis_filter=True, + public_bond_filter=True, + use_feed_forward=False, + trainable_gaussians=False, + use_pondering=True, + fixed_cycles=False, + rescale_rbf=True, + use_graph_norm=False, + use_time_embedding=True, + coupled_interactions=False, + use_mcr=False, + debug=False, + ): + super().__init__( + num_elements=num_elements, + dim_feature=dim_feature, + min_rbf_dis=min_rbf_dis, + max_rbf_dis=max_rbf_dis, + n_interactions=n_interactions, + interactions=interactions, + activation=activation, + num_rbf=num_rbf, + unit_length=unit_length, + rbf_sigma=rbf_sigma, + rbf_function=rbf_function, + cutoff=cutoff, + cutoff_network=cutoff_network, + rescale_rbf=rescale_rbf, + use_graph_norm=use_graph_norm, + use_public_filter=public_dis_filter, + ) + self.network_name = 'MolCT' + self.max_distance = max_rbf_dis + self.min_distance = min_rbf_dis + self.use_distances = use_distances + self.trainable_gaussians = trainable_gaussians + self.use_mcr = use_mcr + self.debug = debug + + if self_dis is None: + self.self_dis = self.min_distance + else: + self.self_dis = self_dis + + self.self_dis_tensor = Tensor([self.self_dis], ms.float32) + + self.n_heads = n_heads + + if use_time_embedding: + time_embedding = self._get_time_signal(max_cycles, dim_feature) + else: + time_embedding = [0 for _ in range(max_cycles)] + + self.use_bonds = use_bonds + + use_dis_inter = False + use_bond_inter = False + use_mix_inter = False + if self.interaction_types is not None: + self.use_distances = False + self.use_bonds = False + for itype in self.interaction_types: + if itype == 'dis': + use_dis_inter = True + self.use_distances = True + elif itype == 'bond': + use_bond_inter = True + self.use_bonds = True + elif itype == 'mix': + use_mix_inter = True + self.use_distances = True + self.use_bonds = True + else: + raise ValueError( + '"interactions" must be "dis", "bond" or "mix"') + else: + if self.use_distances and self.use_bonds: + use_mix_inter = True + elif self.use_distances: + use_dis_inter = True + elif self.use_bonds: + use_bond_inter = True + else: + raise ValueError( + '"use_bonds" and "use_distances" cannot be both "False"!') + + inter_bond_filter = False + if self.use_bonds: + self.bond_embedding = nn.Embedding( + num_bond_types, + dim_feature, + use_one_hot=True, + embedding_table=Normal(1.0)) + if public_bond_filter: + self.bond_filter = Residual(dim_feature, activation=activation) + else: + inter_bond_filter = True + + inter_dis_filter = False + if self.use_distances: + if self.use_public_filter: + self.dis_filter = ResFilter( + num_rbf, dim_feature, self.activation) + # self.dis_filter = Filter(num_rbf,dim_feature,None) + # self.dis_filter = Dense(num_rbf,dim_feature,has_bias=True,activation=None) + else: + self.dis_filter = Dense( + num_rbf, dim_feature, has_bias=True, activation=None) + inter_dis_filter = True + + interaction_list = [] + if coupled_interactions: + if use_dis_inter: + self.dis_interaction = NeuralInteractionUnit( + dim_feature=dim_feature, + 
num_rbf=num_rbf,
+                    n_heads=n_heads,
+                    activation=self.activation,
+                    max_cycles=max_cycles,
+                    time_embedding=time_embedding,
+                    use_pondering=use_pondering,
+                    use_distances=True,
+                    use_bonds=False,
+                    use_dis_filter=inter_dis_filter,
+                    use_bond_filter=False,
+                    fixed_cycles=fixed_cycles,
+                    use_feed_forward=use_feed_forward,
+                )
+            else:
+                self.dis_interaction = None
+
+            if use_bond_inter:
+                self.bond_interaction = NeuralInteractionUnit(
+                    dim_feature=dim_feature,
+                    num_rbf=num_rbf,
+                    n_heads=n_heads,
+                    activation=self.activation,
+                    max_cycles=max_cycles,
+                    time_embedding=time_embedding,
+                    use_pondering=use_pondering,
+                    use_distances=False,
+                    use_bonds=True,
+                    use_dis_filter=False,
+                    use_bond_filter=inter_bond_filter,
+                    fixed_cycles=fixed_cycles,
+                    use_feed_forward=use_feed_forward,
+                )
+            else:
+                self.bond_interaction = None
+
+            if use_mix_inter:
+                self.mix_interaction = NeuralInteractionUnit(
+                    dim_feature=dim_feature,
+                    num_rbf=num_rbf,
+                    n_heads=n_heads,
+                    activation=self.activation,
+                    max_cycles=max_cycles,
+                    time_embedding=time_embedding,
+                    use_pondering=use_pondering,
+                    use_distances=True,
+                    use_bonds=True,
+                    use_dis_filter=inter_dis_filter,
+                    use_bond_filter=inter_bond_filter,
+                    fixed_cycles=fixed_cycles,
+                    use_feed_forward=use_feed_forward,
+                )
+            else:
+                self.mix_interaction = None
+
+            if self.interaction_types is not None:
+                for inter in self.interaction_types:
+                    if inter == 'dis':
+                        interaction_list.append(self.dis_interaction)
+                        self.interaction_typenames.append('D0')
+                    elif inter == 'bond':
+                        interaction_list.append(self.bond_interaction)
+                        self.interaction_typenames.append('B0')
+                    else:
+                        interaction_list.append(self.mix_interaction)
+                        self.interaction_typenames.append('M0')
+            else:
+                # repeat the shared module so coupled layers reuse the weights
+                if use_dis_inter:
+                    interaction_list = [self.dis_interaction] * self.n_interactions
+                    self.interaction_typenames = ['D0',] * self.n_interactions
+                elif use_bond_inter:
+                    interaction_list = [self.bond_interaction] * self.n_interactions
+                    self.interaction_typenames = ['B0',] * self.n_interactions
+                else:
+                    interaction_list = [self.mix_interaction] * self.n_interactions
+                    self.interaction_typenames = ['M0',] * self.n_interactions
+        else:
+            if self.interaction_types is not None:
+                did = 0
+                bid = 0
+                mid = 0
+                for inter in self.interaction_types:
+                    use_distances = False
+                    use_bonds = False
+                    use_dis_filter = False
+                    use_bond_filter = False
+                    if inter == 'dis':
+                        use_distances = True
+                        use_dis_filter = inter_dis_filter
+                        self.interaction_typenames.append('D' + str(did))
+                        did += 1
+                    elif inter == 'bond':
+                        use_bonds = True
+                        self.interaction_typenames.append('B' + str(bid))
+                        use_bond_filter = inter_bond_filter
+                        bid += 1
+                    elif inter == 'mix':
+                        use_distances = True
+                        use_bonds = True
+                        use_dis_filter = inter_dis_filter
+                        use_bond_filter = inter_bond_filter
+                        self.interaction_typenames.append('M' + str(mid))
+                        mid += 1
+
+                    interaction_list.append(
+                        NeuralInteractionUnit(
+                            dim_feature=dim_feature,
+                            num_rbf=num_rbf,
+                            n_heads=n_heads,
+                            activation=self.activation,
+                            max_cycles=max_cycles,
+                            time_embedding=time_embedding,
+                            use_pondering=use_pondering,
+                            use_distances=use_distances,
+                            use_bonds=use_bonds,
+                            use_dis_filter=use_dis_filter,
+                            use_bond_filter=use_bond_filter,
+                            fixed_cycles=fixed_cycles,
+                            use_feed_forward=use_feed_forward,
+                        )
+                    )
+            else:
+                if use_dis_inter:
+                    t = 'D'
+                elif use_bond_inter:
+                    t = 'B'
+                else:
+                    t = 'M'
+
+                self.interaction_typenames = [
+                    t + str(i) for i in range(self.n_interactions)]
+
+                interaction_list = [
+                    NeuralInteractionUnit(
+                        dim_feature=dim_feature,
num_rbf=num_rbf,
+                        n_heads=n_heads,
+                        activation=self.activation,
+                        max_cycles=max_cycles,
+                        time_embedding=time_embedding,
+                        use_pondering=use_pondering,
+                        use_distances=self.use_distances,
+                        use_bonds=self.use_bonds,
+                        use_dis_filter=inter_dis_filter,
+                        use_bond_filter=inter_bond_filter,
+                        fixed_cycles=fixed_cycles,
+                        use_feed_forward=use_feed_forward,
+                    )
+                    for _ in range(self.n_interactions)
+                ]
+
+        self.n_interactions = len(interaction_list)
+        self.interactions = nn.CellList(interaction_list)
+
+        self.lmax_label = []
+        for i in range(self.n_interactions):
+            self.lmax_label.append('l' + str(i) + '_cycles')
+
+        self.fill = P.Fill()
+        self.concat = P.Concat(-1)
+        self.reducesum = P.ReduceSum()
+        self.reducemax = P.ReduceMax()
+        self.tensor_summary = P.TensorSummary()
+        self.scalar_summary = P.ScalarSummary()
+
+    def set_fixed_neighbors(self, flag=True):
+        for interaction in self.interactions:
+            interaction.set_fixed_neighbors(flag)
+
+    def _calc_cutoffs(
+            self,
+            r_ij=1,
+            neighbor_mask=None,
+            bonds=None,
+            bond_mask=None,
+            atom_mask=None):
+        mask = None
+
+        if self.use_distances:
+            if neighbor_mask is not None:
+                mask = self.concat((atom_mask, neighbor_mask))
+
+            if self.cutoff_network is None:
+                # the self term extends the neighbour axis: [B, A, N] -> [B, A, N + 1]
+                new_shape = (r_ij.shape[0], r_ij.shape[1], r_ij.shape[2] + 1)
+                return self.fill(r_ij.dtype, new_shape, 1.0), mask
+            rii_shape = r_ij.shape[:-1] + (1,)
+            r_ii = self.fill(r_ij.dtype, rii_shape, self.self_dis)
+            if atom_mask is not None:
+                r_large = F.ones_like(r_ii) * 5e4
+                r_ii = F.select(atom_mask, r_ii, r_large)
+            # [B, A, N']
+            r_ij = self.concat((r_ii, r_ij))
+
+            return self.cutoff_network(r_ij, mask)
+        if bond_mask is not None:
+            mask = self.concat((atom_mask, bond_mask))
+        return F.cast(mask > 0, ms.float32), mask
+
+    def _get_self_rbf(self):
+        f_ii = self._get_rbf(self.self_dis_tensor)
+        return f_ii
+
+    def _get_time_signal(
+            self,
+            length,
+            channels,
+            min_timescale=1.0,
+            max_timescale=1.0e4):
+        """
+        Generates a [length, channels] timing signal consisting of sinusoids.
+        Adapted from:
+        https://github.com/andreamad8/Universal-Transformer-Pytorch/blob/master/models/common_layer.py
+        """
+        position = np.arange(length)
+        num_timescales = channels // 2
+        log_timescale_increment = (np.log(
+            float(max_timescale) / float(min_timescale)) / (float(num_timescales) - 1))
+        inv_timescales = min_timescale * \
+            np.exp(np.arange(num_timescales).astype(float) * -log_timescale_increment)
+        scaled_time = np.expand_dims(
+            position, 1) * np.expand_dims(inv_timescales, 0)
+
+        signal = np.concatenate(
+            [np.sin(scaled_time), np.cos(scaled_time)], axis=1)
+        signal = np.pad(signal, [[0, 0], [0, channels % 2]],
+                        'constant', constant_values=[0.0, 0.0])
+
+        return Tensor(signal, ms.float32)
diff --git a/MindSPONGE/mindsponge/md/cybertron/neighbors.py b/MindSPONGE/mindsponge/md/cybertron/neighbors.py
new file mode 100644
index 0000000000000000000000000000000000000000..385e20f6183d6254a09f58b9ac9f50007b8ef66e
--- /dev/null
+++ b/MindSPONGE/mindsponge/md/cybertron/neighbors.py
@@ -0,0 +1,145 @@
+# ============================================================================
+# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University
+#
+# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang,
+#         Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao
+#
+# This code is a part of Cybertron-Code package.
+# +# The Cybertron-Code is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""neighbors""" + +import mindspore as ms +from mindspore import nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F + +from cybertron.units import units + +__all__ = [ + "GatherNeighbors", + "Distances", +] + + +class GatherNeighbors(nn.Cell): + r"""Gathering the positions of every atom to its neighbors. + + Args: + + """ + + def __init__(self, dim, fixed_neigh=False): + super().__init__() + self.fixed_neigh = fixed_neigh + + self.broad_ones = P.Ones()((1, 1, dim), ms.int32) + + self.gatherd = P.GatherD() + + def construct(self, inputs, neighbors): + """construct""" + # Construct auxiliary index vector + ns = neighbors.shape + + # Get atomic positions of all neighboring indices + + if self.fixed_neigh: + return F.gather(inputs, neighbors[0], -2) + # [B, A, N] -> [B, A*N, 1] + neigh_idx = F.reshape(neighbors, (ns[0], ns[1] * ns[2], -1)) + neigh_idx = neigh_idx * self.broad_ones + # [B, A*N, V] gather from [B, A, V] + outputs = self.gatherd(inputs, 1, neigh_idx) + # [B, A, N, V] + return F.reshape(outputs, (ns[0], ns[1], ns[2], -1)) + + +class Distances(nn.Cell): + r"""Computing distance of every atom to its neighbors. + + Args: + + + """ + + def __init__( + self, + fixed_atoms=False, + dim=3, + long_dis=units.length( + 10, + 'nm')): + super().__init__() + self.fixed_atoms = fixed_atoms + self.reducesum = P.ReduceSum() + self.pow = P.Pow() + self.gatherd = P.GatherD() + self.norm = nn.Norm(-1) + self.long_dis = long_dis + + self.gather_neighbors = GatherNeighbors(dim, fixed_atoms) + self.maximum = P.Maximum() + + # self.ceil = P.Ceil() + self.floor = P.Floor() + + def construct(self, positions, neighbors, neighbor_mask=None, pbcbox=None): + r"""Compute distance of every atom to its neighbors. + + Args: + positions (ms.Tensor[float], [B, A, 3]): atomic Cartesian coordinates + neighbors (ms.Tensor[int], [B, A, N] or [1, A, N]): indices of neighboring atoms to consider + neighbor_mask (ms.Tensor[bool], [B, A, N] or [1, A, N]): boolean mask for neighbor + positions. Required for the stable computation of forces in + molecules with different sizes. + pbcbox (ms.Tensor[float], [B, 3] or [1, 3]) + + Returns: + ms.Tensor[float]: layer output of (N_b x N_at x N_nbh) shape. 
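+
+            Example (a minimal sketch with toy values, for illustration only):
+
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor
+            >>> dist = Distances(fixed_atoms=False)
+            >>> positions = Tensor([[[0., 0., 0.], [0., 0., 1.]]], ms.float32)  # [1, 2, 3]
+            >>> neighbors = Tensor([[[1], [0]]], ms.int32)                      # [1, 2, 1]
+            >>> dist(positions, neighbors).shape
+            (1, 2, 1)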
+
+        """
+
+        pos_xyz = self.gather_neighbors(positions, neighbors)
+        pos_diff = pos_xyz - F.expand_dims(positions, -2)
+
+        if pbcbox is not None:
+            # [B, 3] -> [B, 1, 1, 3] or [1, 3] -> [1, 1, 1, 3]
+            pbcbox = F.expand_dims(pbcbox, -2)
+            pbcbox = F.expand_dims(pbcbox, -2)
+            halfbox = F.ones_like(pos_diff) * (pbcbox / 2)
+            lmask = pos_diff > halfbox
+            smask = pos_diff < -halfbox
+
+            # wrap displacements back into the periodic box
+            if lmask.any():
+                nbox = self.floor(pos_diff / pbcbox - 0.5) + 1
+                pos = pos_diff - nbox * pbcbox
+                pos_diff = F.select(lmask, pos, pos_diff)
+
+            if smask.any():
+                nbox = self.floor(-pos_diff / pbcbox - 0.5) + 1
+                pos = pos_diff + nbox * pbcbox
+                pos_diff = F.select(smask, pos, pos_diff)
+
+        if neighbor_mask is not None:
+            # push masked (padding) neighbors far away so that they do not
+            # contribute to short-range terms
+            large_diff = pos_diff + self.long_dis
+            smask = (F.ones_like(pos_diff) *
+                     F.expand_dims(neighbor_mask, -1)) > 0
+            pos_diff = F.select(smask, pos_diff, large_diff)
+
+        return self.norm(pos_diff)
diff --git a/MindSPONGE/mindsponge/md/cybertron/rbf.py b/MindSPONGE/mindsponge/md/cybertron/rbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..f163fb21a13124f502f75a13bbbf058c0cd22d10
--- /dev/null
+++ b/MindSPONGE/mindsponge/md/cybertron/rbf.py
@@ -0,0 +1,184 @@
+# ============================================================================
+# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University
+#
+# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang,
+#         Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao
+#
+# This code is a part of Cybertron-Code package.
+#
+# The Cybertron-Code is open-source software based on the AI-framework:
+# MindSpore (https://www.mindspore.cn/)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""rbf"""
+
+import math
+import numpy as np
+import mindspore as ms
+from mindspore import nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+
+from cybertron.units import units
+
+__all__ = [
+    "GaussianSmearing",
+    "LogGaussianDistribution",
+]
+
+# radial_filter in RadialDistribution
+
+
+class GaussianSmearing(nn.Cell):
+    r"""Smear layer using a set of Gaussian functions.
+
+    Args:
+        d_min (float, optional): center of first Gaussian function, :math:`\mu_0`.
+        d_max (float, optional): center of last Gaussian function, :math:`\mu_{N_g}`.
+        num_rbf (int, optional): total number of Gaussian functions, :math:`N_g`.
+        sigma (float, optional): width of the Gaussian functions; if None, it is
+            set from the grid spacing.
+        centered (bool, optional): If True, Gaussians are centered at the origin and
+            the offsets are used as their widths (used e.g. for angular functions).
+        trainable (bool, optional): If True, widths and offset of Gaussian functions
+            are adjusted during training process.
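+
+        Example (illustrative; the values are assumptions, not defaults):
+
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor
+        >>> rbf = GaussianSmearing(d_min=0.0, d_max=1.0, num_rbf=32)
+        >>> distances = Tensor([[[0.5, 0.8]]], ms.float32)  # [B, A, N] = [1, 1, 2]
+        >>> rbf(distances).shape
+        (1, 1, 2, 32)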
+
+    """
+
+    def __init__(
+            self,
+            d_min=0,
+            d_max=units.length(1, 'nm'),
+            num_rbf=32,
+            sigma=None,
+            centered=False,
+            trainable=False):
+        super().__init__()
+        # compute offset and width of Gaussian functions
+        offset = Tensor(np.linspace(d_min, d_max, num_rbf), ms.float32)
+
+        if sigma is None:
+            sigma = (d_max - d_min) / (num_rbf - 1)
+
+        width = sigma * F.ones_like(offset)
+
+        self.width = width
+        self.offset = offset
+        self.centered = centered
+
+        if trainable:
+            self.width = ms.Parameter(width, "widths")
+            self.offset = ms.Parameter(offset, "offset")
+
+        self.exp = P.Exp()
+
+    def construct(self, distances):
+        """Compute smeared-gaussian distance values.
+
+        Args:
+            distances (ms.Tensor): interatomic distance values of
+                (N_b x N_at x N_nbh) shape.
+
+        Returns:
+            ms.Tensor: layer output of (N_b x N_at x N_nbh x N_g) shape.
+
+        """
+
+        ex_dis = F.expand_dims(distances, -1)
+        if not self.centered:
+            # compute width of Gaussian functions (using an overlap of 1
+            # STDDEV)
+            coeff = -0.5 / F.square(self.width)
+            diff = ex_dis - self.offset
+        else:
+            # if Gaussian functions are centered, use offsets to compute widths
+            coeff = -0.5 / F.square(self.offset)
+            # if Gaussian functions are centered, no offset is subtracted
+            diff = ex_dis
+        # compute smeared distance values
+        gauss = self.exp(coeff * F.square(diff))
+        return gauss
+
+
+class LogGaussianDistribution(nn.Cell):
+    """LogGaussianDistribution"""
+    def __init__(
+            self,
+            d_min=units.length(0.05, 'nm'),
+            d_max=units.length(1, 'nm'),
+            num_rbf=32,
+            sigma=None,
+            trainable=False,
+            min_cutoff=False,
+            max_cutoff=False,
+    ):
+        super().__init__()
+        if d_max <= d_min:
+            raise ValueError(
+                'The argument "d_max" must be larger than '
+                'the argument "d_min" in LogGaussianDistribution!')
+
+        if d_min <= 0:
+            raise ValueError('The argument "d_min" must be '
+                             'larger than 0 in LogGaussianDistribution!')
+        self.trainable = trainable
+        self.d_max = d_max
+        self.d_min = d_min / d_max
+        self.min_cutoff = min_cutoff
+        self.max_cutoff = max_cutoff
+
+        self.log = P.Log()
+        self.exp = P.Exp()
+        self.max = P.Maximum()
+        self.min = P.Minimum()
+        self.zeroslike = P.ZerosLike()
+        self.oneslike = P.OnesLike()
+
+        # centers are equally spaced in log space between log(d_min/d_max) and 0
+        log_dmin = math.log(self.d_min)
+        centers = np.linspace(log_dmin, 0, num_rbf)
+        self.centers = Tensor(centers, ms.float32)
+        ones = np.ones_like(centers)
+        self.ones = Tensor(ones, ms.float32)
+
+        if sigma is None:
+            sigma = -log_dmin / (num_rbf - 1)
+        self.rescale = -0.5 / (sigma * sigma)
+
+    def construct(self, distance):
+        """construct"""
+        # rescale distances to (0, 1] relative to d_max
+        dis = distance / self.d_max
+
+        if self.min_cutoff:
+            dis = self.max(dis, self.d_min)
+
+        exdis = F.expand_dims(dis, -1)
+        rbfdis = exdis * self.ones
+
+        log_dis = self.log(rbfdis)
+        log_diff = log_dis - self.centers
+        log_diff2 = F.square(log_diff)
+        log_gauss = self.exp(self.rescale * log_diff2)
+
+        if self.max_cutoff:
+            ones = self.oneslike(exdis)
+            zeros = self.zeroslike(exdis)
+            cuts = F.select(exdis < 1.0, ones, zeros)
+            log_gauss = log_gauss * cuts
+
+        return log_gauss
diff --git a/MindSPONGE/mindsponge/md/cybertron/readouts.py b/MindSPONGE/mindsponge/md/cybertron/readouts.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9ada8252d03ddc5c00b37aa4c92b1573b4cf0c2
--- /dev/null
+++ b/MindSPONGE/mindsponge/md/cybertron/readouts.py
@@ -0,0 +1,1152 @@
+# 
============================================================================ +# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University +# +# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang, +# Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao +# +# This code is a part of Cybertron-Code package. +# +# The Cybertron-Code is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""readouts""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops import functional as F + +from cybertron.units import units +from cybertron.neighbors import GatherNeighbors +from cybertron.base import SmoothReciprocal +from cybertron.cutoff import get_cutoff +from cybertron.aggregators import get_aggregator, get_list_aggregator +from cybertron.decoders import get_decoder + +__all__ = [ + "Readout", + "AtomwiseReadout", + "GraphReadout", + "LongeRangeReadout", + "PairwiseReadout", + "CoulombReadout", + "InteractionsAggregator", +] + + +class InteractionsAggregator(nn.Cell): + """InteractionsAggregator""" + def __init__(self, + n_in, + n_out, + n_interactions, + activation=None, + list_aggregator='sum', + n_aggregator_hiddens=0, + decoders=None, + ): + super().__init__() + + self.n_interactions = n_interactions + self.decoders = None + + if decoders is not None: + if isinstance(decoders, (tuple, list)): + self.decoders = nn.CellList([ + get_decoder(decoders[i], n_in, n_out, activation) + for i in range(n_interactions) + ]) + elif isinstance(decoders, str): + self.decoders = nn.CellList([ + get_decoder(decoders, n_in, n_out, activation) + for i in range(n_interactions) + ]) + else: + raise TypeError( + "Unsupported Decoder type '{}'.".format( + type(decoders))) + + self.list_aggregator = get_list_aggregator( + list_aggregator, n_out, n_interactions, n_aggregator_hiddens, activation) + if self.list_aggregator is None: + raise TypeError( + "ListAggregator cannot be None at InteractionsAggregator") + + def construct(self, xlist, atom_mask=None): + if self.decoders is not None: + ylist = [] + n_interactions = len(xlist) + for i in range(n_interactions): + y = self.decoders[i](xlist[i]) + ylist.append(y) + return self.list_aggregator(ylist, atom_mask) + return self.list_aggregator(xlist, atom_mask) + + +class Readout(nn.Cell): + """Readout""" + def __init__(self, + n_in, + n_out=1, + atom_scale=1, + atom_shift=0, + mol_scale=1, + mol_shift=0, + axis=-2, + atom_ref=None, + scaled_by_atoms_number=True, + averaged_by_atoms_number=False, + activation=None, + decoder=None, + aggregator=None, + unit_energy='kJ/mol', + multi_aggregators=False, + read_all_interactions=False, + n_interactions=None, + interactions_aggregator='sum', + n_aggregator_hiddens=0, + interaction_decoders=None, + ): + 
super().__init__()
+
+        self.name = 'Readout'
+
+        if unit_energy is None:
+            self.output_is_energy = False
+        else:
+            self.output_is_energy = True
+            if not isinstance(unit_energy, str):
+                raise TypeError('Type of unit_energy must be str')
+            unit_energy = units.check_energy_unit(unit_energy)
+            units.set_energy_unit(unit_energy)
+        self.unit_energy = unit_energy
+
+        self.averaged_by_atoms_number = averaged_by_atoms_number
+
+        self.atom_ref = atom_ref
+
+        if not isinstance(n_in, int):
+            raise TypeError('Type of n_in must be int')
+        self.n_in = n_in
+
+        n_out, num_output_dim = self._check_type_and_number(
+            n_out, 'n_out', int)
+        self.n_out = n_out
+        self.multi_n_out = n_out
+        self.total_out = n_out
+
+        self.multi_output_number = False
+        if num_output_dim > 1:
+            n_out = list(set(n_out))
+            if len(n_out) > 1:
+                self.multi_output_number = True
+            else:
+                n_out = n_out[0]
+
+        self.n_out = n_out
+
+        atom_scale, num_atom_scale = self._check_type_and_number(
+            atom_scale, 'atom_scale', (int, float, Tensor))
+        self.atom_scale = atom_scale
+
+        atom_shift, num_atom_shift = self._check_type_and_number(
+            atom_shift, 'atom_shift', (int, float, Tensor))
+        self.atom_shift = atom_shift
+
+        if not isinstance(mol_scale, (float, int, Tensor)):
+            raise TypeError('Type of mol_scale must be float, int or Tensor')
+        self.mol_scale = mol_scale
+
+        if not isinstance(mol_shift, (float, int, Tensor)):
+            raise TypeError('Type of mol_shift must be float, int or Tensor')
+        self.mol_shift = mol_shift
+
+        if not isinstance(axis, int):
+            raise TypeError('Type of axis must be int')
+        self.axis = axis
+
+        activation, num_activation = self._check_type_and_number(
+            activation, 'activation')
+        self.activation = activation
+
+        if not isinstance(scaled_by_atoms_number, bool):
+            raise TypeError('Type of scaled_by_atoms_number must be bool')
+        self.scaled_by_atoms_number = scaled_by_atoms_number
+
+        if not isinstance(averaged_by_atoms_number, bool):
+            raise TypeError('Type of averaged_by_atoms_number must be bool')
+
+        if read_all_interactions and interaction_decoders is not None:
+            decoder = None
+
+        self.decoder = None
+        self.multi_decoders = False
+        self.num_decoder = 1
+        if isinstance(decoder, (tuple, list)):
+            self.multi_decoders = True
+            self.num_decoder = len(decoder)
+            if self.num_decoder == 1:
+                decoder = decoder[0]
+            elif self.num_decoder == 0:
+                raise ValueError('Number of decoder cannot be zero')
+
+        if num_output_dim > 1:
+            self.multi_decoders = True
+            if self.num_decoder != num_output_dim:
+                if self.num_decoder == 1:
+                    self.decoder = [decoder,] * num_output_dim
+                    self.num_decoder = num_output_dim
+                else:
+                    raise ValueError('Number of decoder mismatch')
+
+        self.aggregator = None
+        self.multi_aggregators = multi_aggregators
+        self.num_aggregator = 1
+        if isinstance(aggregator, (tuple, list)):
+            self.multi_aggregators = True
+            self.num_aggregator = len(aggregator)
+            if self.num_aggregator == 1:
+                aggregator = aggregator[0]
+            if self.num_aggregator == 0:
+                raise ValueError('Number of aggregator cannot be zero')
+
+        if self.num_aggregator == 1 and self.multi_aggregators:
+            aggregator = [aggregator,] * self.num_decoder
+            self.num_aggregator = self.num_decoder
+        self.aggregator = aggregator
+
+        if self.multi_decoders:
+            if self.num_aggregator != self.num_decoder and self.num_aggregator != 1:
+                raise ValueError('Number of aggregator mismatch')
+        else:
+            if self.multi_aggregators:
+                raise ValueError(
+                    'multi aggregators must be used with multi decoders')
+
+        self.split_slice = ()
+        if self.multi_decoders:
+            if num_output_dim != self.num_decoder:
+                if num_output_dim == 1:
+                    self.multi_n_out = [n_out,] * self.num_decoder
+                else:
+                    raise ValueError('Number of n_out mismatch')
+
+            sect = 0
+            for i in range(len(self.multi_n_out) - 1):
+                sect = self.multi_n_out[i] + sect
+                self.split_slice += (sect,)
+
+            self.total_out = 0
+            for n in self.multi_n_out:
+                self.total_out += n
+
+            if num_activation != self.num_decoder:
+                if num_activation == 1:
+                    self.activation = [activation,] * self.num_decoder
+                else:
+                    raise ValueError('Number of activation mismatch')
+        else:
+            if num_atom_scale > 1:
+                raise ValueError('Number of atom_scale mismatch')
+            if num_atom_shift > 1:
+                raise ValueError('Number of atom_shift mismatch')
+            if num_activation > 1:
+                raise ValueError('Number of activation mismatch')
+
+        self.str_unit_energy = (
+            " " + unit_energy) if self.output_is_energy else ""
+
+        self.split = P.Split(-1, self.num_decoder)
+
+        self.multi_atom_scale = None
+        self.multi_atom_shift = None
+        self.multi_atom_ref = None
+        if self.multi_decoders:
+            self.multi_atom_scale = self._split_by_decoders(
+                self.atom_scale, 'atom_scale')
+            self.multi_atom_shift = self._split_by_decoders(
+                self.atom_shift, 'atom_shift')
+            if atom_ref is not None:
+                self.multi_atom_ref = self._split_by_decoders(
+                    self.atom_ref, 'atom_ref')
+
+        self.read_all_interactions = read_all_interactions
+        self.n_interactions = n_interactions
+
+        if read_all_interactions:
+            if interaction_decoders is not None:
+                if decoder is not None:
+                    raise ValueError(
+                        'decoder and interaction_decoders cannot be used at the same time')
+                if self.multi_decoders:
+                    raise ValueError(
+                        'Multiple decoders cannot support interaction_decoders')
+                if n_interactions is None or n_interactions == 0:
+                    raise ValueError(
+                        'n_interactions must be set up when using interaction_decoders')
+
+            self.interaction_decoders = interaction_decoders
+            self.interactions_aggregator = InteractionsAggregator(
+                n_in,
+                n_out,
+                n_interactions,
+                activation=activation,
+                list_aggregator=interactions_aggregator,
+                n_aggregator_hiddens=n_aggregator_hiddens,
+                decoders=interaction_decoders,
+            )
+        else:
+            self.interactions_aggregator = None
+            self.interaction_decoders = None
+
+        self.concat = P.Concat(-1)
+
+    def print_info(self):
+        """print info"""
+
+        self._print_plus_info()
+
+        if self.read_all_interactions:
+            print("------read all interactions with interactions aggregator: " +
+                  str(self.interactions_aggregator.list_aggregator))
+            if self.interaction_decoders:
+                print("---------with independent decoders: ")
+                for i in range(self.n_interactions):
+                    decoder = self.interactions_aggregator.decoders[i]
+                    print("---------" + str(i + 1) + '. decoder "' + str(decoder) +
+                          '" with activation "' + str(decoder.activation) + '".')
+            else:
+                print("---------without independent decoder.")
+        else:
+            print("------read last interactions:")
+
+        if self.multi_decoders:
+            print("------with " + str(self.num_decoder) +
+                  " multiple decoders and " +
+                  ("aggregators." if self.multi_aggregators else
+                   ('common "' + str(self.aggregator) + '" aggregator.')))
+            for i in range(self.num_decoder):
+                print("------" + str(i + 1) + '. decoder "' + str(self.decoder[i]) +
+                      '" with activation "' + str(self.decoder[i].activation) + '".')
+                if self.multi_aggregators:
+                    print("---------with aggregator: " +
+                          self.aggregator[i].name)
+                print("---------with readout dimension: " +
+                      str(self.multi_n_out[i]))
+                print("---------with activation function: " +
+                      str(self.activation[i]))
+
+            print("------with multiple scale and shift:")
+            for i in range(self.num_decoder):
+                print("------" + str(i + 1) +
+                      ". output with dimension: " + str(self.multi_n_out[i]))
+                print("---------with atom scale: " +
+                      str(self.multi_atom_scale[i]) + self.str_unit_energy)
+                print("---------with atom shift: " +
+                      str(self.multi_atom_shift[i]) + self.str_unit_energy)
+            print("------scaled by atoms number: " +
+                  str(self.scaled_by_atoms_number))
+
+        else:
+            if self.decoder is not None:
+                print("------with decoder: " + str(self.decoder))
+                print("------with activation function: " +
+                      str(self.decoder.activation))
+            print("------with readout dimension: " + str(self.n_out))
+            print("------with atom scale: " +
+                  str(self.atom_scale) + self.str_unit_energy)
+            print("------with atom shift: " +
+                  str(self.atom_shift) + self.str_unit_energy)
+            print("------scaled by atoms number: " +
+                  str(self.scaled_by_atoms_number))
+
+        print("------with total readout dimension: " + str(self.total_out))
+        print("------with molecular scale: " +
+              str(self.mol_scale) + self.str_unit_energy)
+        print("------with molecular shift: " +
+              str(self.mol_shift) + self.str_unit_energy)
+        print("------averaged by atoms number: " +
+              ('Yes' if self.averaged_by_atoms_number else 'No'))
+
+    def _print_plus_info(self):
+        print('------calculate long range interaction: No')
+
+    def _check_type_and_number(self, inputs, name, types=None):
+        """_check_type_and_number"""
+        num = 1
+        if isinstance(inputs, (tuple, list)):
+            num = len(inputs)
+            if num == 0:
+                raise ValueError("Size of " + name + " cannot be zero!")
+            if num == 1:
+                inputs = inputs[0]
+                if (types is not None) and (not isinstance(inputs, types)):
+                    raise TypeError(
+                        "Unsupported " + name +
+                        " type '{}'.".format(type(inputs)))
+        elif (types is not None) and (not isinstance(inputs, types)):
+            raise TypeError(
+                "Unsupported " + name +
+                " type '{}'.".format(type(inputs)))
+
+        return inputs, num
+
+    def _split_by_decoders(self, inputs, name):
+        """_split_by_decoders"""
+        if self.multi_decoders:
+            if isinstance(inputs, (tuple, list)):
+                if len(inputs) != self.num_decoder:
+                    if len(inputs) == 1:
+                        inputs = inputs * self.num_decoder
+                    else:
+                        raise ValueError('Number of ' + name + ' mismatch')
+            elif isinstance(inputs, (float, int)):
+                if isinstance(inputs, float):
+                    inputs = Tensor(inputs, ms.float32)
+                if isinstance(inputs, int):
+                    inputs = Tensor(inputs, ms.int32)
+                inputs = [inputs,] * self.num_decoder
+            elif isinstance(inputs, Tensor):
+                if inputs.shape[-1] != self.total_out:
+                    raise ValueError('Last dimension of ' + name + ' mismatch')
+
+                if self.multi_output_number:
+                    inputs = msnp.split(inputs, self.split_slice, -1)
+                else:
+                    inputs = self.split(inputs)
+            else:
+                raise TypeError(
+                    "Unsupported Decoder type '{}'.".format(
+                        type(inputs)))
+
+        return inputs
+
+
+class AtomwiseReadout(Readout):
+    """
+    Predicts atom-wise contributions and accumulates a global prediction, e.g. for
+    the energy.
+
+    Args:
+        n_in (int): input dimension of the atom-wise representation
+        n_out (int or list of int): output dimension of the target property (default: 1)
+        atom_scale (float, int or ms.Tensor): scale applied to every atom-wise output (default: 1)
+        atom_shift (float, int or ms.Tensor): shift applied to every atom-wise output (default: 0)
+        mol_scale (float, int or ms.Tensor): scale applied to the aggregated molecular output (default: 1)
+        mol_shift (float, int or ms.Tensor): shift applied to the aggregated molecular output (default: 0)
+        axis (int): atom axis over which atom-wise values are aggregated (default: -2)
+        atom_ref (ms.Tensor or None): reference single-atom properties. Expects
+            a (max_z + 1) x 1 array where atom_ref[Z] corresponds to the reference
+            property of element Z. The value of atom_ref[0] must be zero, as this
+            corresponds to the reference property for "mask" atoms. (default: None)
+        scaled_by_atoms_number (bool): if True, divide atom-wise outputs by the
+            number of atoms before aggregation (default: False)
+        averaged_by_atoms_number (bool): if True, divide the molecular output by
+            the number of atoms (default: False)
+        activation (Cell or function): activation function used by the decoder(s)
+        decoder (str, list or None): decoder(s) mapping n_in features to n_out
+            values; if None, the representation is used directly (default: 'halve')
+        aggregator (str or list): aggregator(s) accumulating atom-wise values into
+            a molecular prediction (default: 'sum')
+        unit_energy (str or None): energy unit of the output (default: 'kJ/mol')
+        multi_aggregators (bool): if True, use one aggregator per decoder (default: False)
+        read_all_interactions (bool): if True, read out all interaction layers
+            instead of only the last one (default: False)
+        n_interactions (int or None): number of interaction layers; required when
+            interaction_decoders is used (default: None)
+        interactions_aggregator (str): aggregator combining the outputs of all
+            interaction layers (default: 'sum')
+        n_aggregator_hiddens (int): number of hidden layers in the list aggregator (default: 0)
+        interaction_decoders (str, list or None): independent decoders applied to
+            each interaction layer (default: None)
+
+    Returns:
+        ms.Tensor: molecular prediction; the outputs of multiple decoders are
+        concatenated along the last axis.
+
+    """
+
+    def __init__(
+            self,
+            n_in,
+            n_out=1,
+            atom_scale=1,
+            atom_shift=0,
+            mol_scale=1,
+            mol_shift=0,
+            axis=-2,
+            atom_ref=None,
+            scaled_by_atoms_number=False,
+            averaged_by_atoms_number=False,
+            activation=None,
+            decoder='halve',
+            aggregator='sum',
+            unit_energy='kJ/mol',
+            multi_aggregators=False,
+            read_all_interactions=False,
+            n_interactions=None,
+            interactions_aggregator='sum',
+            n_aggregator_hiddens=0,
+            interaction_decoders=None,
+    ):
+        super().__init__(
+            n_in=n_in,
+            n_out=n_out,
+            atom_scale=atom_scale,
+            atom_shift=atom_shift,
+            mol_scale=mol_scale,
+            mol_shift=mol_shift,
+            axis=axis,
+            atom_ref=atom_ref,
+            scaled_by_atoms_number=scaled_by_atoms_number,
+            averaged_by_atoms_number=averaged_by_atoms_number,
+            activation=activation,
+            decoder=decoder,
+            aggregator=aggregator,
+            unit_energy=unit_energy,
+            multi_aggregators=multi_aggregators,
+            read_all_interactions=read_all_interactions,
+            n_interactions=n_interactions,
+            interactions_aggregator=interactions_aggregator,
+            n_aggregator_hiddens=n_aggregator_hiddens,
+            interaction_decoders=interaction_decoders,
+        )
+        self.name = 'Atomwise'
+
+        if self.multi_decoders:
+            decoders = []
+            for i in range(self.num_decoder):
+                decoder_t = get_decoder(
+                    self.decoder[i],
+                    self.n_in,
+                    self.multi_n_out[i],
+                    self.activation[i])
+                if decoder_t is not None:
+                    decoders.append(decoder_t)
+                else:
+                    raise ValueError('Multi decoders cannot include None type')
+            self.decoder = nn.CellList(decoders)
+        else:
+            self.decoder = get_decoder(
+                decoder, self.n_in, self.n_out, self.activation)
+
+        if self.decoder is None and self.interaction_decoders is None and self.n_in != self.n_out:
+            raise ValueError(
+                "When decoder is None, n_out (" + str(n_out) +
+                ") must be equal to n_in (" + str(n_in) + ")")
+
+        if self.multi_aggregators:
+            aggregators = []
+            for i in range(self.num_aggregator):
+                aggregator_t = get_aggregator(
+                    self.aggregator[i], self.multi_n_out[i], axis)
+                if aggregator_t is not None:
+                    aggregators.append(aggregator_t)
+                else:
+                    raise ValueError(
+                        'Multi aggregators cannot include None type')
+            self.aggregator = nn.CellList(aggregators)
+
+        else:
+            if self.multi_output_number:
+                agg_dict = {}
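+                # build one aggregator per distinct output dimension and share
+                # it among all decoders with that dimension (mapping filled below)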
+ for n_out_t in self.n_out: + aggregator_t = get_aggregator(self.aggregator, n_out_t, axis) + if aggregator_t is not None: + agg_dict[n_out_t] = aggregator_t + + aggregators = [] + for i in range(self.num_decoder): + aggregators.append(agg_dict[self.multi_n_out[i]]) + self.aggregator = nn.CellList(aggregators) + self.multi_aggregators = True + else: + self.aggregator = get_aggregator(aggregator, self.n_out, axis) + + def construct( + self, + x, + xlist, + atoms_types=None, + atom_mask=None, + atoms_number=None): + r""" + predicts atomwise property + """ + + if self.read_all_interactions: + x = self.interactions_aggregator(xlist, atom_mask) + + y = None + if self.multi_decoders: + ytuple = () + for i in range(self.num_decoder): + yi = self.decoder[i](x) + if self.multi_aggregators: + yi = yi * \ + self.multi_atom_scale[i] + self.multi_atom_shift[i] + if self.scaled_by_atoms_number: + yi = yi / atoms_number + if self.atom_ref is not None: + yi += F.gather(self.multi_atom_ref[i], atoms_types, 0) + yi = self.aggregator[i](yi, atom_mask, atoms_number) + + ytuple = ytuple + (yi,) + + y = self.concat(ytuple) + + if not self.multi_aggregators: + y = y * self.atom_scale + self.atom_shift + if self.scaled_by_atoms_number: + y = y / atoms_number + if self.atom_ref is not None: + y += F.gather(self.atom_ref, atoms_types, 0) + y = self.aggregator(y, atom_mask, atoms_number) + else: + if self.decoder is not None: + y = self.decoder(x) + else: + y = x + + y = y * self.atom_scale + self.atom_shift + if self.scaled_by_atoms_number: + y = y / atoms_number + + if self.atom_ref is not None: + y += F.gather(self.atom_ref, atoms_types, 0) + + if self.aggregator is not None: + y = self.aggregator(y, atom_mask) + + y = y * self.mol_scale + self.mol_shift + + if self.averaged_by_atoms_number: + if atoms_number is None: + atoms_number = x.shape[self.axis] + y = y / atoms_number + + return y + + +class GraphReadout(Readout): + """ + + Args: + + Returns: + + """ + + def __init__( + self, + n_in, + n_out=1, + atom_scale=1, + atom_shift=0, + mol_scale=1, + mol_shift=0, + axis=-2, + atom_ref=None, + scaled_by_atoms_number=False, + averaged_by_atoms_number=False, + activation=None, + decoder='halve', + aggregator='mean', + unit_energy=None, + multi_aggregators=False, + read_all_interactions=False, + n_interactions=None, + interactions_aggregator='sum', + n_aggregator_hiddens=0, + interaction_decoders=None, + ): + super().__init__( + n_in=n_in, + n_out=n_out, + atom_scale=atom_scale, + atom_shift=atom_shift, + mol_scale=mol_scale, + mol_shift=mol_shift, + axis=axis, + atom_ref=atom_ref, + scaled_by_atoms_number=scaled_by_atoms_number, + averaged_by_atoms_number=averaged_by_atoms_number, + activation=activation, + decoder=decoder, + aggregator=aggregator, + unit_energy=unit_energy, + multi_aggregators=multi_aggregators, + read_all_interactions=read_all_interactions, + n_interactions=n_interactions, + interactions_aggregator=interactions_aggregator, + n_aggregator_hiddens=n_aggregator_hiddens, + interaction_decoders=interaction_decoders, + ) + + self.name = 'Graph' + + if self.interaction_decoders is not None: + raise ValueError('GraphReadout cannot use interaction_decoders') + + if self.multi_aggregators: + aggregators = [] + for i in range(self.num_aggregator): + aggregator_t = get_aggregator( + self.aggregator[i], self.n_in, axis) + if aggregator_t is not None: + aggregators.append(aggregator_t) + else: + raise ValueError( + 'Multi aggregators cannot include None type') + self.aggregator = nn.CellList(aggregators) + 
+            self.aggregator = get_aggregator(aggregator, self.n_in, axis)
+            if self.aggregator is None:
+                raise ValueError("aggregator cannot be None at GraphReadout")
+
+        if self.multi_decoders:
+            decoders = []
+            for i in range(self.num_decoder):
+                decoder_t = get_decoder(
+                    self.decoder[i],
+                    self.n_in,
+                    self.multi_n_out[i],
+                    self.activation[i])
+                if decoder_t is not None:
+                    decoders.append(decoder_t)
+                else:
+                    raise ValueError('Multi decoders cannot include None type')
+            self.decoder = nn.CellList(decoders)
+        else:
+            self.decoder = get_decoder(
+                decoder, self.n_in, self.n_out, self.activation)
+            if self.decoder is None and n_in != n_out:
+                raise ValueError(
+                    "When decoder is None, n_out (" +
+                    str(n_out) +
+                    ") must be equal to n_in (" +
+                    str(n_in) +
+                    ")")
+
+        self.reduce_sum = P.ReduceSum()
+
+    def construct(
+            self,
+            x,
+            xlist,
+            atoms_types=None,
+            atom_mask=None,
+            atoms_number=None):
+        r"""
+        predicts graph property
+        """
+
+        if self.read_all_interactions:
+            x = self.interactions_aggregator(xlist, atom_mask)
+
+        y = None
+        if self.multi_decoders:
+            if self.multi_aggregators:
+                agg = None
+            else:
+                agg = self.aggregator(x, atom_mask, atoms_number)
+
+            ytuple = ()
+            for i in range(self.num_decoder):
+                if self.multi_aggregators:
+                    agg = self.aggregator[i](x, atom_mask, atoms_number)
+
+                yi = self.decoder[i](agg)
+
+                ytuple = ytuple + (yi,)
+
+            y = self.concat(ytuple)
+
+            y = y * self.atom_scale + self.atom_shift
+            if self.scaled_by_atoms_number:
+                y = y / atoms_number
+
+        else:
+            agg = self.aggregator(x, atom_mask, atoms_number)
+
+            if self.decoder is not None:
+                y = self.decoder(agg)
+            else:
+                y = agg
+
+            y = y * self.atom_scale + self.atom_shift
+            if self.scaled_by_atoms_number:
+                y = y / atoms_number
+
+        if self.atom_ref is not None:
+            ref = F.gather(self.atom_ref, atoms_types, 0)
+            ref = self.reduce_sum(ref, self.axis)
+            y += ref
+
+        y = y * self.mol_scale + self.mol_shift
+
+        if self.averaged_by_atoms_number:
+            if atoms_number is None:
+                atoms_number = x.shape[self.axis]
+            y = y / atoms_number
+
+        return y
+
+
+class LongeRangeReadout(Readout):
+    """LongeRangeReadout"""
+    def __init__(self,
+                 dim_feature,
+                 atom_scale=1,
+                 atom_shift=0,
+                 mol_scale=1,
+                 mol_shift=0,
+                 axis=-2,
+                 atom_ref=None,
+                 activation=None,
+                 decoder='halve',
+                 longrange_decoder=None,
+                 unit_energy='kcal/mol',
+                 cutoff_function='gaussian',
+                 cutoff_max=units.length(1, 'nm'),
+                 cutoff_min=units.length(0.8, 'nm'),
+                 fixed_neigh=False,
+                 read_all_interactions=False,
+                 n_interactions=None,
+                 interactions_aggregator='sum',
+                 n_aggregator_hiddens=0,
+                 interaction_decoders=None,
+                 ):
+        super().__init__(
+            n_in=dim_feature,
+            n_out=1,
+            atom_scale=atom_scale,
+            atom_shift=atom_shift,
+            mol_scale=mol_scale,
+            mol_shift=mol_shift,
+            axis=axis,
+            atom_ref=atom_ref,
+            scaled_by_atoms_number=False,
+            averaged_by_atoms_number=False,
+            activation=activation,
+            decoder=decoder,
+            aggregator='sum',
+            unit_energy=unit_energy,
+            multi_aggregators=False,
+            read_all_interactions=read_all_interactions,
+            n_interactions=n_interactions,
+            interactions_aggregator=interactions_aggregator,
+            n_aggregator_hiddens=n_aggregator_hiddens,
+            interaction_decoders=interaction_decoders,
+        )
+
+        self.name = 'longrange'
+        self.longrange_decoder = longrange_decoder
+        self.coulomb_const = units.coulomb()
+
+        if self.multi_decoders:
+            raise ValueError('LongRangeReadout cannot use multiple decoders')
+
+        self.aggregator = get_aggregator('sum', 1, axis)
+
+        self.gather_neighbors = GatherNeighbors(dim_feature, fixed_neigh)
+        self.squeeze = P.Squeeze(-1)
+        self.reduce_sum = P.ReduceSum()
+        self.keep_sum = P.ReduceSum(keep_dims=True)
+
+        self.smooth_reciprocal = SmoothReciprocal()
+
+        if cutoff_function is not None:
+            self.cutoff_function = get_cutoff(
+                cutoff_function,
+                r_max=cutoff_max,
+                r_min=cutoff_min,
+                return_mask=False,
+                reverse=True)
+        else:
+            self.cutoff_function = None
+
+    def set_fixed_neighbors(self, flag=True):
+        self.fixed_neigh = flag
+        self.gather_neighbors.fixed_neigh = flag
+
+    def _print_plus_info(self):
+        print('------calculate long range interaction: Yes')
+        print('---------with method for long range interaction: ' + self.name)
+        print('---------with coulomb constant: ' + str(self.coulomb_const * 2) + ' ' + self.unit_energy + '*' +
+              self.unit_dis)
+
+
+class CoulombReadout(LongeRangeReadout):
+    """
+    Predicts atomwise short-range energies and atomwise charges, and adds a
+    long-range Coulomb correction computed from the predicted charges.
+
+    The decoder outputs two values per atom: an atomwise energy and an
+    atomwise charge. The charges of neighbouring atoms are multiplied,
+    weighted by a smoothed reciprocal distance and an optional cutoff
+    function, and summed to give the long-range Coulomb energy.
+
+    Args:
+        dim_feature (int): dimension of the atomwise representation
+        atom_scale (float): scale factor for the atomwise energy (default: 1)
+        atom_shift (float): shift for the atomwise energy (default: 0)
+        mol_scale (float): scale factor for the molecular energy (default: 1)
+        mol_shift (float): shift for the molecular energy (default: 0)
+        atom_ref (Tensor or None): reference single-atom properties (default: None)
+        cutoff_function (str): smooth cutoff applied to the reciprocal
+            distances (default: 'gaussian')
+        cutoff_max (float): outer cutoff radius (default: 1 nm)
+        cutoff_min (float): inner cutoff radius (default: 0.8 nm)
+
+    Returns:
+        Tensor, sum of the aggregated short-range and Coulomb energies.
+
+    """
+
+    def __init__(
+            self,
+            dim_feature,
+            atom_scale=1,
+            atom_shift=0,
+            mol_scale=1,
+            mol_shift=0,
+            axis=-2,
+            atom_ref=None,
+            activation=None,
+            decoder='halve',
+            longrange_decoder=None,
+            unit_energy='kcal/mol',
+            cutoff_function='gaussian',
+            cutoff_max=units.length(1, 'nm'),
+            cutoff_min=units.length(0.8, 'nm'),
+            fixed_neigh=False,
+            read_all_interactions=False,
+            n_interactions=None,
+            interactions_aggregator='sum',
+            n_aggregator_hiddens=0,
+            interaction_decoders=None,
+    ):
+        super().__init__(
+            dim_feature=dim_feature,
+            atom_scale=atom_scale,
+            atom_shift=atom_shift,
+            mol_scale=mol_scale,
+            mol_shift=mol_shift,
+            axis=axis,
+            atom_ref=atom_ref,
+            activation=activation,
+            decoder=decoder,
+            longrange_decoder=longrange_decoder,
+            unit_energy=unit_energy,
+            cutoff_function=cutoff_function,
+            cutoff_max=cutoff_max,
+            cutoff_min=cutoff_min,
+            fixed_neigh=fixed_neigh,
+            read_all_interactions=read_all_interactions,
+            n_interactions=n_interactions,
+            interactions_aggregator=interactions_aggregator,
+            n_aggregator_hiddens=n_aggregator_hiddens,
+            interaction_decoders=interaction_decoders,
+        )
+        self.name = 'Coulomb'
+
+        self.decoder = get_decoder(decoder, dim_feature, 2, self.activation)
+        if self.decoder is None:
+            raise ValueError("Decoder in CoulombReadout cannot be None")
+
+        if longrange_decoder is not None:
+            raise ValueError("CoulombReadout cannot support longrange_decoder")
+
+        self.split = P.Split(-1, 2)
+
+    def construct(
+            self,
+            x,
+            xlist,
+            atoms_types=None,
+            atom_mask=None,
+            atoms_number=None,
+            distances=None,
+            neighbors=None,
+            neighbor_mask=None):
+        r"""
+        predicts atomwise energy with a long-range Coulomb correction
+        """
+
+        if self.read_all_interactions:
+            x = self.interactions_aggregator(xlist, atom_mask)
+
+        y = self.decoder(x)
+
+        # [B,A,2] -> [B,A,1] * 2
+        (ei, qi) = self.split(y)
+        # [B,A,1] -> [B,A,N,1]
+        qij = self.gather_neighbors(qi, neighbors)
+        # [B,A,N,1] -> [B,A,N]
+        qij = self.squeeze(qij)
+        # [B,A,N] * [B,A,1] -> [B,A,N]
+        qiqj = qij * qi
+
+        sij = self.smooth_reciprocal(distances, neighbor_mask)
+        if self.cutoff_function is not None:
+            sij = sij * self.cutoff_function(distances, neighbor_mask)
+
+        eq = qiqj * sij
+        eq = self.reduce_sum(eq, -1)
+        eq = self.aggregator(eq, atom_mask, atoms_number) * \
+            self.coulomb_const / 2.
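+        # The neighbour list enumerates every pair (i, j) from both atoms,
+        # so the pairwise Coulomb sum above is halved to avoid double counting.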
+
+        ei = ei * self.atom_scale + self.atom_shift
+
+        if self.atom_ref is not None:
+            ei += F.gather(self.atom_ref, atoms_types, 0)
+
+        ei = self.aggregator(ei, atom_mask, atoms_number)
+
+        ei = ei * self.mol_scale + self.mol_shift
+
+        return ei + eq
+
+
+class PairwiseReadout(LongeRangeReadout):
+    """PairwiseReadout"""
+    def __init__(self,
+                 dim_feature,
+                 atom_scale=1,
+                 atom_shift=0,
+                 mol_scale=1,
+                 mol_shift=0,
+                 axis=-2,
+                 atom_ref=None,
+                 activation=None,
+                 decoder='halve',
+                 longrange_decoder=None,
+                 unit_energy='kcal/mol',
+                 cutoff_function='gaussian',
+                 cutoff_max=units.length(1, 'nm'),
+                 cutoff_min=units.length(0.8, 'nm'),
+                 fixed_neigh=False,
+                 read_all_interactions=False,
+                 n_interactions=None,
+                 interactions_aggregator='sum',
+                 n_aggregator_hiddens=0,
+                 interaction_decoders=None,
+                 ):
+        super().__init__(
+            dim_feature=dim_feature,
+            atom_scale=atom_scale,
+            atom_shift=atom_shift,
+            mol_scale=mol_scale,
+            mol_shift=mol_shift,
+            axis=axis,
+            atom_ref=atom_ref,
+            activation=activation,
+            decoder=decoder,
+            longrange_decoder=longrange_decoder,
+            unit_energy=unit_energy,
+            cutoff_function=cutoff_function,
+            cutoff_max=cutoff_max,
+            cutoff_min=cutoff_min,
+            fixed_neigh=fixed_neigh,
+            read_all_interactions=read_all_interactions,
+            n_interactions=n_interactions,
+            interactions_aggregator=interactions_aggregator,
+            n_aggregator_hiddens=n_aggregator_hiddens,
+            interaction_decoders=interaction_decoders,
+        )
+
+        self.name = 'pairwise'
+
+        self.decoder = get_decoder(decoder, dim_feature, 1, self.activation)
+        if self.decoder is None:
+            raise ValueError("Decoder in PairwiseReadout cannot be None")
+
+        self.longrange_decoder = get_decoder(
+            decoder, dim_feature, 1, self.activation)
+        if self.longrange_decoder is None:
+            raise ValueError(
+                "longrange_decoder in PairwiseReadout cannot be None")
+
+        self.squeeze = P.Squeeze(-1)
+
+    def construct(
+            self,
+            x,
+            xlist,
+            atoms_types=None,
+            atom_mask=None,
+            atoms_number=None,
+            distances=None,
+            neighbors=None,
+            neighbor_mask=None):
+        """construct"""
+
+        if self.read_all_interactions:
+            x = self.interactions_aggregator(xlist, atom_mask)
+
+        ei = self.decoder(x)
+
+        # [B,A,V] -> [B,A,1,V]
+        xi = F.expand_dims(x, -2)
+        # [B,A,N,V]
+        xij = self.gather_neighbors(x, neighbors)
+        # [B,A,N,V] = [B,A,N,V] * [B,A,1,V]
+        xixj = xij * xi
+
+        # [B,A,N,1]
+        qiqj = self.longrange_decoder(xixj)
+
+        qiqj = self.squeeze(qiqj)
+        sij = self.smooth_reciprocal(distances, neighbor_mask)
+
+        if self.cutoff_function is not None:
+            cij = self.cutoff_function(distances, neighbor_mask)
+            sij = sij * cij
+
+        eq = qiqj * sij
+        eq = self.reduce_sum(eq, -1)
+        eq = self.aggregator(eq, atom_mask, atoms_number) * \
+            self.coulomb_const / 2.
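+        # Unlike CoulombReadout, the pair weight here is predicted directly from
+        # the feature product x_i * x_j by a learned decoder; the same factor of
+        # 1/2 corrects for each pair appearing twice in the neighbour list.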
+ + ei = ei * self.atom_scale + self.atom_shift + + if self.atom_ref is not None: + ei += F.gather(self.atom_ref, atoms_types, 0) + + ei = self.aggregator(ei, atom_mask, atoms_number) + + ei = ei * self.mol_scale + self.mol_shift + + return ei + eq diff --git a/MindSPONGE/mindsponge/md/cybertron/train.py b/MindSPONGE/mindsponge/md/cybertron/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf8d5aefaa3406b2fcee762602ee4812452666d --- /dev/null +++ b/MindSPONGE/mindsponge/md/cybertron/train.py @@ -0,0 +1,1186 @@ +# ============================================================================ +# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University +# +# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang, +# Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao +# +# This code is a part of Cybertron-Code package. +# +# The Cybertron-Code is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train""" + +import os +from shutil import copyfile +from collections import deque + +import numpy as np +import mindspore as ms +from mindspore import nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.train.callback import Callback +from mindspore.train.serialization import save_checkpoint +from mindspore.nn.metrics import Metric +from mindspore.train._utils import _make_directory +from mindspore.nn.learning_rate_schedule import LearningRateSchedule +from mindspore._checkparam import Validator as validator + +_cur_dir = os.getcwd() + +__all__ = [ + "DatasetWhitening", + "OutputScaleShift", + "WithForceLossCell", + "WithLabelLossCell", + "WithForceEvalCell", + "WithLabelEvalCell", + "TrainMonitor", + "MAE", + "MSE", + "MLoss", + "TransformerLR", +] + + +class DatasetWhitening(nn.Cell): + """DatasetWhitening""" + def __init__(self, + mol_scale=1, + mol_shift=0, + atom_scale=1, + atom_shift=0, + atom_ref=None, + axis=-2, + ): + super().__init__() + + self.mol_scale = mol_scale + self.mol_shift = mol_shift + + self.atom_scale = atom_scale + self.atom_shift = atom_shift + + self.atom_ref = atom_ref + self.axis = axis + + self.reduce_sum = P.ReduceSum() + self.keep_sum = P.ReduceSum(keep_dims=True) + + def construct(self, label, atoms_number, atom_types=None): + """construct""" + ref = 0 + if self.atom_ref is not None: + ref = F.gather(self.atom_ref, atom_types, 0) + ref = self.reduce_sum(ref, self.axis) + + whiten_label = (label - self.mol_shift) / self.mol_scale + whiten_label = (whiten_label - ref - self.atom_shift * + atoms_number) / self.atom_scale + + return whiten_label + + +class OutputScaleShift(nn.Cell): + """OutputScaleShift""" + def __init__(self, + mol_scale=1, + mol_shift=0, + atom_scale=1, + atom_shift=0, + atom_ref=None, + 
axis=-2,
+                 ):
+        super().__init__()
+
+        self.mol_scale = mol_scale
+        self.mol_shift = mol_shift
+
+        self.atom_scale = atom_scale
+        self.atom_shift = atom_shift
+
+        self.atom_ref = atom_ref
+        self.axis = axis
+
+        self.reduce_sum = P.ReduceSum()
+        self.keep_sum = P.ReduceSum(keep_dims=True)
+
+    def construct(self, outputs, atoms_number, atom_types=None):
+        """construct"""
+        ref = 0
+        if self.atom_ref is not None:
+            ref = F.gather(self.atom_ref, atom_types, 0)
+            ref = self.reduce_sum(ref, self.axis)
+
+        scaled_outputs = outputs * self.atom_scale + \
+            self.atom_shift * atoms_number + ref
+        scaled_outputs = scaled_outputs * self.mol_scale + self.mol_shift
+
+        return scaled_outputs
+
+
+class LossWithEnergyAndForces(nn.loss.loss.LossBase):
+    """LossWithEnergyAndForces"""
+    def __init__(self,
+                 ratio_energy=1,
+                 ratio_forces=100,
+                 force_aggregate='sum',
+                 reduction='mean',
+                 scale_dis=1,
+                 ratio_normlize=True,
+                 ):
+        super().__init__(reduction)
+
+        if force_aggregate not in ('mean', 'sum'):
+            raise ValueError(
+                f"force_aggregate method '{force_aggregate}' is not supported")
+        self.force_aggregate = force_aggregate
+
+        self.scale_dis = scale_dis
+        self.ratio_normlize = ratio_normlize
+
+        self.ratio_energy = ratio_energy
+        self.ratio_forces = ratio_forces
+
+        self.norm = 1
+        if self.ratio_normlize:
+            self.norm = ratio_energy + ratio_forces
+
+        self.reduce_mean = P.ReduceMean()
+        self.reduce_sum = P.ReduceSum()
+
+    def _calc_loss(self, diff):
+        return diff
+
+    def construct(
+            self,
+            pred_energy,
+            label_energy,
+            pred_forces=None,
+            label_forces=None,
+            atoms_number=1,
+            atom_mask=None):
+        """construct"""
+
+        if pred_forces is None:
+            loss = self._calc_loss(pred_energy - label_energy)
+            return self.get_loss(loss)
+
+        eloss = 0
+        if self.ratio_energy > 0:
+            ediff = (pred_energy - label_energy) / atoms_number
+            eloss = self._calc_loss(ediff)
+
+        floss = 0
+        if self.ratio_forces > 0:
+            fdiff = (pred_forces - label_forces) * self.scale_dis
+            fdiff = self._calc_loss(fdiff)
+            if self.force_aggregate == 'mean':
+                fdiff = self.reduce_mean(fdiff, -1)
+            else:
+                fdiff = self.reduce_sum(fdiff, -1)
+
+            if atom_mask is None:
+                floss = self.reduce_mean(fdiff, -1)
+            else:
+                fdiff = fdiff * atom_mask
+                floss = self.reduce_sum(fdiff, -1)
+                floss = floss / atoms_number
+
+        y = (eloss * self.ratio_energy + floss * self.ratio_forces) / self.norm
+
+        natoms = F.cast(atoms_number, pred_energy.dtype)
+        weights = natoms / self.reduce_mean(natoms)
+
+        return self.get_loss(y, weights)
+
+
+class MAELoss(LossWithEnergyAndForces):
+    """MAELoss"""
+    def __init__(self,
+                 ratio_energy=1,
+                 ratio_forces=0,
+                 force_aggregate='sum',
+                 reduction='mean',
+                 scale_dis=1,
+                 ratio_normlize=True,
+                 ):
+        super().__init__(
+            ratio_energy=ratio_energy,
+            ratio_forces=ratio_forces,
+            force_aggregate=force_aggregate,
+            reduction=reduction,
+            scale_dis=scale_dis,
+            ratio_normlize=ratio_normlize,
+        )
+        self.abs = P.Abs()
+
+    def _calc_loss(self, diff):
+        return self.abs(diff)
+
+
+class MSELoss(LossWithEnergyAndForces):
+    """MSELoss"""
+    def __init__(self,
+                 ratio_energy=1,
+                 ratio_forces=0,
+                 force_aggregate='sum',
+                 reduction='mean',
+                 scale_dis=1,
+                 ratio_normlize=True,
+                 ):
+        super().__init__(
+            ratio_energy=ratio_energy,
+            ratio_forces=ratio_forces,
+            force_aggregate=force_aggregate,
+            reduction=reduction,
+            scale_dis=scale_dis,
+            ratio_normlize=ratio_normlize,
+        )
+        self.square = P.Square()
+
+    def _calc_loss(self, diff):
+        return self.square(diff)
+
+
+class WithCell(nn.Cell):
+    """WithCell"""
+    def __init__(self, datatypes):
+        
super().__init__(auto_prefix=False) + + self.fulltypes = 'RZCNnBbLlE' + self.datatypes = datatypes + + self.r = -1 # positions + self.z = -1 # atom_types + self.c = -1 # pbcbox + self.n = -1 # neighbors + self.n_mask = -1 # neighbor_mask + self.b = -1 # bonds + self.b_mask = -1 # bond_mask + self.l = -1 # far_neighbors + self.l_mask = -1 # far_mask + self.e = -1 # energy + + def _find_type_indexes(self, datatypes): + """_find_type_indexes""" + if not isinstance(datatypes, str): + raise TypeError('Type of "datatypes" must be str') + + for datatype in datatypes: + if self.fulltypes.count(datatype) == 0: + raise ValueError('Unknown datatype: ' + datatype) + + for datatype in self.fulltypes: + num = datatypes.count(datatype) + if num > 1: + raise ValueError( + 'There are ' + + str(num) + + ' "' + + datatype + + '" in datatype "' + + datatypes + + '".') + + self.r = datatypes.find('R') # positions + self.z = datatypes.find('Z') # atom_types + self.c = datatypes.find('C') # pbcbox + self.n = datatypes.find('N') # neighbors + self.n_mask = datatypes.find('n') # neighbor_mask + self.b = datatypes.find('B') # bonds + self.b_mask = datatypes.find('b') # bond_mask + self.l = datatypes.find('L') # far_neighbors + self.l_mask = datatypes.find('l') # far_mask + self.e = datatypes.find('E') # energy + + if self.e < 0: + raise TypeError('The datatype "E" must be included!') + + self.keep_sum = P.ReduceSum(keep_dims=True) + + +class WithForceLossCell(WithCell): + """WithForceLossCell""" + def __init__(self, + datatypes, + backbone, + loss_fn, + do_whitening=False, + mol_scale=1, + mol_shift=0, + atom_scale=1, + atom_shift=0, + atom_ref=None, + ): + super().__init__(datatypes=datatypes) + + self.scale = mol_scale * atom_scale + self.do_whitening = do_whitening + if do_whitening: + self.whitening = DatasetWhitening( + mol_scale=mol_scale, + mol_shift=mol_shift, + atom_scale=atom_scale, + atom_shift=atom_shift, + atom_ref=atom_ref, + ) + else: + self.whitening = None + + self.fulltypes = 'RZCNnBbLlFE' + self._find_type_indexes(datatypes) + self.f = datatypes.find('F') # force + + if self.f < 0: + raise TypeError( + 'The datatype "F" must be included in WithForceLossCell!') + + self._backbone = backbone + self._loss_fn = loss_fn + + self.atom_types = self._backbone.atom_types + + self.grad_op = C.GradOperation() + + def construct(self, *inputs): + """construct""" + inputs = inputs + (None,) + + positions = inputs[self.r] + atom_types = inputs[self.z] + pbcbox = inputs[self.c] + neighbors = inputs[self.n] + neighbor_mask = inputs[self.n_mask] + bonds = inputs[self.b] + bond_mask = inputs[self.b_mask] + far_neighbors = inputs[self.l] + far_mask = inputs[self.l_mask] + + energy = inputs[self.e] + out = self._backbone( + positions, + atom_types, + pbcbox, + neighbors, + neighbor_mask, + bonds, + bond_mask, + far_neighbors, + far_mask, + ) + + forces = inputs[self.f] + fout = -1 * self.grad_op(self._backbone)( + positions, + atom_types, + pbcbox, + neighbors, + neighbor_mask, + bonds, + bond_mask, + far_neighbors, + far_mask, + ) + + if atom_types is None: + atom_types = self.atom_types + + atoms_number = F.cast(atom_types > 0, out.dtype) + atoms_number = self.keep_sum(atoms_number, -1) + + if self.do_whitening: + energy = self.whitening(energy, atoms_number, atom_types) + forces /= self.scale + + if atom_types is None: + return self._loss_fn(out, energy, fout, forces) + atom_mask = atom_types > 0 + return self._loss_fn( + out, + energy, + fout, + forces, + atoms_number, + atom_mask) + + @property + def 
backbone_network(self): + return self._backbone + + +class WithLabelLossCell(WithCell): + """WithLabelLossCell""" + def __init__(self, + datatypes, + backbone, + loss_fn, + do_whitening=False, + mol_scale=1, + mol_shift=0, + atom_scale=1, + atom_shift=0, + atom_ref=None, + # with_penalty=False, + ): + super().__init__(datatypes=datatypes) + self._backbone = backbone + self._loss_fn = loss_fn + # self.with_penalty = with_penalty + + self.atom_types = self._backbone.atom_types + + self.do_whitening = do_whitening + if do_whitening: + self.whitening = DatasetWhitening( + mol_scale=mol_scale, + mol_shift=mol_shift, + atom_scale=atom_scale, + atom_shift=atom_shift, + atom_ref=atom_ref, + ) + else: + self.whitening = None + + self._find_type_indexes(datatypes) + + def construct(self, *inputs): + """construct""" + + inputs = inputs + (None,) + + positions = inputs[self.r] + atom_types = inputs[self.z] + pbcbox = inputs[self.c] + neighbors = inputs[self.n] + neighbor_mask = inputs[self.n_mask] + bonds = inputs[self.b] + bond_mask = inputs[self.b_mask] + far_neighbors = inputs[self.l] + far_mask = inputs[self.l_mask] + + out = self._backbone( + positions, + atom_types, + pbcbox, + neighbors, + neighbor_mask, + bonds, + bond_mask, + far_neighbors, + far_mask, + ) + + label = inputs[self.e] + + if atom_types is None: + atom_types = self.atom_types + + atoms_number = F.cast(atom_types > 0, out.dtype) + atoms_number = self.keep_sum(atoms_number, -1) + + if self.do_whitening: + label = self.whitening(label, atoms_number, atom_types) + + return self._loss_fn(out, label) + + +class WithForceEvalCell(WithCell): + """WithForceEvalCell""" + def __init__(self, + datatypes, + network, + loss_fn=None, + add_cast_fp32=False, + do_whitening=False, + mol_scale=1, + mol_shift=0, + atom_scale=1, + atom_shift=0, + atom_ref=None, + ): + super().__init__(datatypes) + + self.scale = mol_scale * atom_scale + self.do_whitening = do_whitening + self.scaleshift = None + self.whitening = None + if do_whitening: + self.scaleshift = OutputScaleShift( + mol_scale=mol_scale, + mol_shift=mol_shift, + atom_scale=atom_scale, + atom_shift=atom_shift, + atom_ref=atom_ref, + ) + if loss_fn is not None: + self.whitening = DatasetWhitening( + mol_scale=mol_scale, + mol_shift=mol_shift, + atom_scale=atom_scale, + atom_shift=atom_shift, + atom_ref=atom_ref, + ) + + self.fulltypes = 'RZCNnBbLlFE' + self._find_type_indexes(datatypes) + self.f = datatypes.find('F') # force + + if self.f < 0: + raise TypeError( + 'The datatype "F" must be included in WithForceEvalCell!') + + self._network = network + self._loss_fn = loss_fn + self.add_cast_fp32 = add_cast_fp32 + + self.atom_types = self._network.atom_types + + self.reduce_sum = P.ReduceSum() + + self.grad_op = C.GradOperation() + + def construct(self, *inputs): + """construct""" + inputs = inputs + (None,) + + positions = inputs[self.r] + atom_types = inputs[self.z] + pbcbox = inputs[self.c] + neighbors = inputs[self.n] + neighbor_mask = inputs[self.n_mask] + bonds = inputs[self.b] + bond_mask = inputs[self.b_mask] + far_neighbors = inputs[self.l] + far_mask = inputs[self.l_mask] + + outputs = self._network( + positions, + atom_types, + pbcbox, + neighbors, + neighbor_mask, + bonds, + bond_mask, + far_neighbors, + far_mask, + ) + + foutputs = -1 * self.grad_op(self._network)( + positions, + atom_types, + pbcbox, + neighbors, + neighbor_mask, + bonds, + bond_mask, + far_neighbors, + far_mask, + ) + + forces = inputs[self.f] + energy = inputs[self.e] + + if self.add_cast_fp32: + forces = 
F.mixed_precision_cast(ms.float32, forces) + energy = F.mixed_precision_cast(ms.float32, energy) + outputs = F.cast(outputs, ms.float32) + + if atom_types is None: + atom_types = self.atom_types + + atoms_number = F.cast(atom_types > 0, outputs.dtype) + atoms_number = self.keep_sum(atoms_number, -1) + + loss = 0 + if self._loss_fn is not None: + energy_t = energy + forces_t = forces + if self.do_whitening: + energy_t = self.whitening(energy_t, atoms_number, atom_types) + forces_t /= self.scale + + atom_mask = atom_types > 0 + loss = self._loss_fn( + outputs, + energy_t, + foutputs, + forces_t, + atoms_number, + atom_mask) + + if self.do_whitening: + outputs = self.scaleshift(outputs, atoms_number, atom_types) + foutputs *= self.scale + + return loss, outputs, energy, foutputs, forces, atoms_number + + +class WithLabelEvalCell(WithCell): + """WithLabelEvalCell""" + def __init__(self, + datatypes, + network, + loss_fn=None, + add_cast_fp32=False, + do_whitening=False, + mol_scale=1, + mol_shift=0, + atom_scale=1, + atom_shift=0, + atom_ref=None, + ): + super().__init__(datatypes=datatypes) + self._network = network + self._loss_fn = loss_fn + self.add_cast_fp32 = add_cast_fp32 + self.reducesum = P.ReduceSum(keep_dims=True) + + self.atom_types = self._network.atom_types + + self.do_whitening = do_whitening + self.scaleshift = None + self.whitening = None + if do_whitening: + self.scaleshift = OutputScaleShift( + mol_scale=mol_scale, + mol_shift=mol_shift, + atom_scale=atom_scale, + atom_shift=atom_shift, + atom_ref=atom_ref, + ) + if loss_fn is not None: + self.whitening = DatasetWhitening( + mol_scale=mol_scale, + mol_shift=mol_shift, + atom_scale=atom_scale, + atom_shift=atom_shift, + atom_ref=atom_ref, + ) + + self._find_type_indexes(datatypes) + + def construct(self, *inputs): + """construct""" + inputs = inputs + (None,) + + positions = inputs[self.r] + atom_types = inputs[self.z] + pbcbox = inputs[self.c] + neighbors = inputs[self.n] + neighbor_mask = inputs[self.n_mask] + bonds = inputs[self.b] + bond_mask = inputs[self.b_mask] + far_neighbors = inputs[self.l] + far_mask = inputs[self.l_mask] + + outputs = self._network( + positions, + atom_types, + pbcbox, + neighbors, + neighbor_mask, + bonds, + bond_mask, + far_neighbors, + far_mask, + ) + + label = inputs[self.e] + if self.add_cast_fp32: + label = F.mixed_precision_cast(ms.float32, label) + outputs = F.cast(outputs, ms.float32) + + if atom_types is None: + atom_types = self.atom_types + + atoms_number = F.cast(atom_types > 0, outputs.dtype) + atoms_number = self.keep_sum(atoms_number, -1) + + loss = 0 + if self._loss_fn is not None: + label_t = label + if self.do_whitening: + label_t = self.whitening(label, atoms_number, atom_types) + loss = self._loss_fn(outputs, label_t) + + if self.do_whitening: + outputs = self.scaleshift(outputs, atoms_number, atom_types) + + return loss, outputs, label, atoms_number + + +class TrainMonitor(Callback): + """TrainMonitor""" + def __init__( + self, + model, + name, + directory=None, + per_epoch=1, + per_step=0, + avg_steps=0, + eval_dataset=None, + best_ckpt_metrics=None): + super().__init__() + if not isinstance(per_epoch, int) or per_epoch < 0: + raise ValueError("per_epoch must be int and >= 0.") + if not isinstance(per_step, int) or per_step < 0: + raise ValueError("per_step must be int and >= 0.") + + self.avg_steps = avg_steps + self.loss_record = 0 + self.train_num = 0 + if avg_steps > 0: + self.train_num = deque(maxlen=avg_steps) + self.loss_record = deque(maxlen=avg_steps) + + if 
per_epoch * per_step != 0:
+            if per_epoch == 1:
+                per_epoch = 0
+            else:
+                raise ValueError(
+                    "per_epoch and per_step cannot both be larger than 0 at the same time.")
+        self.model = model
+        self._per_epoch = per_epoch
+        self._per_step = per_step
+        self.eval_dataset = eval_dataset
+
+        if directory is not None:
+            self._directory = _make_directory(directory)
+        else:
+            self._directory = _cur_dir
+
+        self._filename = name + '-info.data'
+        self._ckptfile = name + '-best'
+        self._ckptdata = name + '-ckpt.data'
+
+        self.num_ckpt = 1
+        self.best_value = 5e4
+        self.best_ckpt_metrics = best_ckpt_metrics
+
+        self.last_loss = 0
+        self.record = []
+
+        self.output_title = True
+        filename = os.path.join(self._directory, self._filename)
+        if os.path.exists(filename):
+            with open(filename, "r") as f:
+                lines = f.readlines()
+                if len(lines) > 1:
+                    os.remove(filename)
+
+    def _write_cpkt_file(self, filename, info, network):
+        ckptfile = os.path.join(self._directory, filename + '.ckpt')
+        ckptbck = os.path.join(self._directory, filename + '.bck.ckpt')
+        ckptdata = os.path.join(self._directory, self._ckptdata)
+
+        if os.path.exists(ckptfile):
+            os.rename(ckptfile, ckptbck)
+        save_checkpoint(network, ckptfile)
+        with open(ckptdata, "a") as f:
+            f.write(info + os.linesep)
+
+    def _output_data(self, cb_params):
+        """_output_data"""
+        cur_epoch = cb_params.cur_epoch_num
+
+        opt = cb_params.optimizer
+        if opt is None:
+            opt = cb_params.train_network.optimizer
+
+        if opt.dynamic_lr:
+            step = opt.global_step
+            if not isinstance(step, int):
+                step = step.asnumpy()[0]
+        else:
+            step = cb_params.cur_step_num
+
+        if self.avg_steps > 0:
+            mov_avg = sum(self.loss_record) / sum(self.train_num)
+        else:
+            mov_avg = self.loss_record / self.train_num
+
+        title = "#! FIELDS step"
+        info = 'Epoch: ' + str(cur_epoch) + ', Step: ' + str(step)
+        outdata = '{:>10d}'.format(step)
+
+        lr = opt.learning_rate
+        if opt.dynamic_lr:
+            step = F.cast(step, ms.int32)
+            if opt.is_group_lr:
+                lr = ()
+                for learning_rate in opt.learning_rate:
+                    current_dynamic_lr = learning_rate(step - 1)
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = opt.learning_rate(step - 1)
+            lr = lr.asnumpy()
+
+        title += ' learning_rate'
+        info += ', Learning_rate: ' + str(lr)
+        outdata += '{:>15e}'.format(lr)
+
+        title += " last_loss avg_loss"
+        info += ', Last_Loss: ' + \
+            str(self.last_loss) + ', Avg_loss: ' + str(mov_avg)
+        outdata += '{:>15e}'.format(self.last_loss) + '{:>15e}'.format(mov_avg)
+
+        _make_directory(self._directory)
+
+        if self.eval_dataset is not None:
+            eval_metrics = self.model.eval(
+                self.eval_dataset, dataset_sink_mode=False)
+            for k, v in eval_metrics.items():
+                info += ', '
+                info += k
+                info += ': '
+                info += str(v)
+
+                if isinstance(v, np.ndarray) and v.size > 1:
+                    for i in range(v.size):
+                        title += (' ' + k + str(i))
+                        outdata += '{:>15e}'.format(v[i])
+                else:
+                    title += (' ' + k)
+                    outdata += '{:>15e}'.format(v)
+            if self.best_ckpt_metrics in eval_metrics.keys():
+                self.eval_dataset_process(eval_metrics, info, cb_params)
+
+        print(info, flush=True)
+        filename = os.path.join(self._directory, self._filename)
+        if self.output_title:
+            with open(filename, "a") as f:
+                f.write(title + os.linesep)
+            self.output_title = False
+        with open(filename, "a") as f:
+            f.write(outdata + os.linesep)
+
+    def eval_dataset_process(self, eval_metrics, info, cb_params):
+        """eval_dataset_process"""
+        vnow = eval_metrics[self.best_ckpt_metrics]
+        if isinstance(vnow, np.ndarray) and len(vnow) > 1:
+            output_ckpt = vnow < self.best_value
+            num_best = np.count_nonzero(output_ckpt)
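+            # A metric vector is tracked component-wise: the shared '-best'
+            # checkpoint is saved once, then copied for every improved component.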
+ if num_best > 0: + self._write_cpkt_file( + self._ckptfile, info, cb_params.train_network) + source_ckpt = os.path.join( + self._directory, self._ckptfile + '.ckpt') + for i in range(len(vnow)): + if output_ckpt[i]: + dest_ckpt = os.path.join( + self._directory, self._ckptfile + '-' + str(i) + '.ckpt') + bck_ckpt = os.path.join( + self._directory, self._ckptfile + '-' + str(i) + '.ckpt.bck') + if os.path.exists(dest_ckpt): + os.rename(dest_ckpt, bck_ckpt) + copyfile(source_ckpt, dest_ckpt) + self.best_value = np.minimum(vnow, self.best_value) + else: + if vnow < self.best_value: + self._write_cpkt_file( + self._ckptfile, info, cb_params.train_network) + self.best_value = vnow + + def step_end(self, run_context): + """step_end""" + cb_params = run_context.original_args() + loss = cb_params.net_outputs + + if isinstance(loss, (tuple, list)): + if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): + loss = loss[0] + + if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): + loss = np.mean(loss.asnumpy()) + + nbatch = len(cb_params.train_dataset_element[0]) + batch_loss = loss * nbatch + + self.last_loss = loss + if self.avg_steps > 0: + self.loss_record.append(batch_loss) + self.train_num.append(nbatch) + else: + self.loss_record += batch_loss + self.train_num += nbatch + + if self._per_step > 0 and cb_params.cur_step_num % self._per_step == 0: + self._output_data(cb_params) + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + + if self._per_epoch > 0 and cur_epoch % self._per_epoch == 0: + self._output_data(cb_params) + + +class MaxError(Metric): + """MaxError""" + def __init__(self, indexes=None, reduce_all_dims=True): + super().__init__() + self.clear() + self._indexes = [1, 2] if not indexes else indexes + if reduce_all_dims: + self.axis = None + else: + self.axis = 0 + + def clear(self): + self._max_error = 0 + + def update(self, *inputs): + y_pred = self._convert_data(inputs[self._indexes[0]]) + y = self._convert_data(inputs[self._indexes[1]]) + diff = y.reshape(y_pred.shape) - y_pred + max_error = diff.max() - diff.min() + if max_error > self._max_error: + self._max_error = max_error + + def eval(self): + return self._max_error + + +class Error(Metric): + """Error""" + def __init__(self, + indexes=None, + reduce_all_dims=True, + averaged_by_atoms=False, + atom_aggregate='mean', + ): + super().__init__() + self.clear() + self._indexes = [1, 2] if not indexes else indexes + self.read_atoms_number = False + if len(self._indexes) > 2: + self.read_atoms_number = True + + self.reduce_all_dims = reduce_all_dims + + if atom_aggregate.lower() not in ('mean', 'sum'): + raise ValueError( + 'aggregate_by_atoms method must be "mean" or "sum"') + self.atom_aggregate = atom_aggregate.lower() + + if reduce_all_dims: + self.axis = None + else: + self.axis = 0 + + if averaged_by_atoms and not self.read_atoms_number: + raise ValueError( + 'When to use averaged_by_atoms, the index of atom number must be set at "indexes".') + + self.averaged_by_atoms = averaged_by_atoms + + self._error_sum = 0 + self._samples_num = 0 + + def clear(self): + self._error_sum = 0 + self._samples_num = 0 + + def _calc_error(self, y, y_pred): + return y.reshape(y_pred.shape) - y_pred + + def update(self, *inputs): + """update""" + y_pred = self._convert_data(inputs[self._indexes[0]]) + y = self._convert_data(inputs[self._indexes[1]]) + + error = self._calc_error(y, y_pred) + if len(error.shape) > 2: + axis = 
tuple(range(2, len(error.shape))) + if self.atom_aggregate == 'mean': + error = np.mean(error, axis=axis) + else: + error = np.sum(error, axis=axis) + + tot = y.shape[0] + if self.read_atoms_number: + natoms = self._convert_data(inputs[self._indexes[2]]) + if self.averaged_by_atoms: + error /= natoms + elif self.reduce_all_dims: + tot = np.sum(natoms) + if natoms.shape[0] != y.shape[0]: + tot *= y.shape[0] + elif self.reduce_all_dims: + tot = error.size + + self._error_sum += np.sum(error, axis=self.axis) + self._samples_num += tot + + def eval(self): + if self._samples_num == 0: + raise RuntimeError('Total samples num must not be 0.') + return self._error_sum / self._samples_num + +# mean absolute error + + +class MAE(Error): + """MAE""" + def __init__(self, + indexes=None, + reduce_all_dims=True, + averaged_by_atoms=False, + atom_aggregate='mean', + ): + super().__init__( + indexes=indexes, + reduce_all_dims=reduce_all_dims, + averaged_by_atoms=averaged_by_atoms, + atom_aggregate=atom_aggregate, + ) + + def _calc_error(self, y, y_pred): + return np.abs(y.reshape(y_pred.shape) - y_pred) + +# mean square error + + +class MSE(Error): + """MSE""" + def __init__(self, + indexes=None, + reduce_all_dims=True, + averaged_by_atoms=False, + atom_aggregate='mean', + ): + super().__init__( + indexes=indexes, + reduce_all_dims=reduce_all_dims, + averaged_by_atoms=averaged_by_atoms, + atom_aggregate=atom_aggregate, + ) + + def _calc_error(self, y, y_pred): + return np.square(y.reshape(y_pred.shape) - y_pred) + +# mean norm error + + +class MNE(Error): + """MNE""" + def __init__(self, + indexes=None, + reduce_all_dims=True, + averaged_by_atoms=False, + atom_aggregate='mean', + ): + super().__init__( + indexes=indexes, + reduce_all_dims=reduce_all_dims, + averaged_by_atoms=averaged_by_atoms, + atom_aggregate=atom_aggregate, + ) + + def _calc_error(self, y, y_pred): + diff = y.reshape(y_pred.shape) - y_pred + return np.linalg.norm(diff, axis=-1) + +# root mean square error + + +class RMSE(Error): + """RMSE""" + def __init__(self, + indexes=None, + reduce_all_dims=True, + averaged_by_atoms=False, + atom_aggregate='mean', + ): + super().__init__( + indexes=indexes, + reduce_all_dims=reduce_all_dims, + averaged_by_atoms=averaged_by_atoms, + atom_aggregate=atom_aggregate, + ) + + def _calc_error(self, y, y_pred): + return np.square(y.reshape(y_pred.shape) - y_pred) + + def eval(self): + if self._samples_num == 0: + raise RuntimeError('Total samples num must not be 0.') + return np.sqrt(self._error_sum / self._samples_num) + + +class MLoss(Metric): + """MLoss""" + def __init__(self, index=0): + super().__init__() + self.clear() + self._index = index + + def clear(self): + self._sum_loss = 0 + self._total_num = 0 + + def update(self, *inputs): + """update""" + + loss = self._convert_data(inputs[self._index]) + + if loss.ndim == 0: + loss = loss.reshape(1) + + if loss.ndim != 1: + raise ValueError( + "Dimensions of loss must be 1, but got {}".format( + loss.ndim)) + + loss = loss.mean(-1) + self._sum_loss += loss + self._total_num += 1 + + def eval(self): + if self._total_num == 0: + raise RuntimeError('Total number can not be 0.') + return self._sum_loss / self._total_num + + +class TransformerLR(LearningRateSchedule): + """TransformerLR""" + def __init__(self, learning_rate=1.0, warmup_steps=4000, dimension=1): + super().__init__() + if not isinstance(learning_rate, float): + raise TypeError("learning_rate must be float.") + validator.check_non_negative_float( + learning_rate, "learning_rate", self.cls_name) 
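+        # The schedule below follows the Transformer ("Noam") rule:
+        # lr = base_lr * dim^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)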
+ validator.check_positive_int( + warmup_steps, 'warmup_steps', self.cls_name) + + self.learning_rate = learning_rate + + self.pow = P.Pow() + self.warmup_scale = self.pow(F.cast(warmup_steps, ms.float32), -1.5) + self.dim_scale = self.pow(F.cast(dimension, ms.float32), -0.5) + + self.min = P.Minimum() + + def construct(self, global_step): + step_num = F.cast(global_step, ms.float32) + lr_percent = self.dim_scale * \ + self.min(self.pow(step_num, -0.5), step_num * self.warmup_scale) + return self.learning_rate * lr_percent diff --git a/MindSPONGE/mindsponge/md/cybertron/units.py b/MindSPONGE/mindsponge/md/cybertron/units.py new file mode 100644 index 0000000000000000000000000000000000000000..af01e4f0b5264f8b510d8136b5c6fc490d7509f0 --- /dev/null +++ b/MindSPONGE/mindsponge/md/cybertron/units.py @@ -0,0 +1,209 @@ +# ============================================================================ +# Copyright 2021 The AIMM team at Shenzhen Bay Laboratory & Peking University +# +# People: Yi Isaac Yang, Jun Zhang, Diqing Chen, Yaqiang Zhou, Huiyang Zhang, +# Yupeng Huang, Yijie Xia, Yao-Kun Lei, Lijiang Yang, Yi Qin Gao +# +# This code is a part of Cybertron-Code package. +# +# The Cybertron-Code is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""units"""

+import mindspore as ms
+
+
+class Units:
+    """Units"""
+    def __init__(self):
+
+        self.float = ms.float32
+        self.int = ms.int32
+
+        # length
+        self._length_def = 1.0
+        self._length = self._length_def
+
+        self._nm = self._length_def
+        self._um = self._nm * 1e3
+        self._angstrom = self._nm * 0.1
+        self._bohr = self._nm * 0.052917721067
+
+        self._length_dict = {
+            'nm': self._nm,
+            'um': self._um,
+            'a': self._angstrom,
+            'angstrom': self._angstrom,
+            'bohr': self._bohr,
+        }
+
+        self._length_name = {
+            'nm': 'nm',
+            'um': 'um',
+            'a': 'Angstrom',
+            'angstrom': 'Angstrom',
+            'bohr': 'Bohr',
+        }
+
+        self._length_unit_def = 'nm'
+        self._length_unit = self._length_unit_def
+
+        # energy
+        self._energy_def = 1.0
+        self._energy = self._energy_def
+
+        self._kj_mol = self._energy_def
+        self._j_mol = self._kj_mol * 1e-3
+        self._kcal_mol = self._kj_mol * 4.184
+        self._cal_mol = self._kj_mol * 4.184e-3
+        self._hartree = self._kj_mol * 2625.5002
+        self._ev = self._kj_mol * 96.48530749925793
+
+        self._energy_dict = {
+            'kj/mol': self._kj_mol,
+            'j/mol': self._j_mol,
+            'kcal/mol': self._kcal_mol,
+            'cal/mol': self._cal_mol,
+            'ha': self._hartree,
+            'hartree': self._hartree,
+            'ev': self._ev,
+        }
+
+        self._energy_name = {
+            'kj/mol': 'kJ/mol',
+            'j/mol': 'J/mol',
+            'kcal/mol': 'kcal/mol',
+            'cal/mol': 'cal/mol',
+            'ha': 'Hartree',
+            'hartree': 'Hartree',
+            'ev': 'eV',
+        }
+
+        self._energy_unit_def = 'kj/mol'
+        self._energy_unit = self._energy_unit_def
+
+        # fundamental physical constants (SI units)
+        self._avogadro_number = 6.02214076e23
+        self._boltzmann_constant = 1.380649e-23
+        self._gas_constant = 8.31446261815324
+        self._elementary_charge = 1.602176634e-19
+        self._coulomb_constant = 8.9875517923e9
+
+        # Boltzmann constant
+        self._boltzmann_def = 8.31446261815324e-3  # kj/mol
+        self._boltzmann = self._boltzmann_def
+
+        # Coulomb constant
+        self._coulomb_def = 138.93545764498226165718756672623  # kj/mol*nm
+        self._coulomb = self._coulomb_def
+
+    def check_length_unit(self, unit):
+        if unit.lower() not in self._length_dict.keys():
+            raise ValueError('length unit "' + unit + '" is not recorded!')
+        return self._length_name[unit.lower()]
+
+    def check_energy_unit(self, unit):
+        if unit.lower() not in self._energy_dict.keys():
+            raise ValueError('energy unit "' + unit + '" is not recorded!')
+        return self._energy_name[unit.lower()]
+
+    def set_default(self):
+        self._length_unit = self._length_unit_def
+        self._length = self._length_def
+        self._energy_unit = self._energy_unit_def
+        self._energy = self._energy_def
+        self._coulomb = self._coulomb_def
+        self._boltzmann = self._boltzmann_def
+
+    def set_length_unit(self, unit):
+        self.check_length_unit(unit.lower())
+        if unit.lower() != self._length_unit:
+            self._length_unit = unit.lower()
+            self._length = self._length_dict[unit.lower()]
+
+            self._coulomb = self._coulomb_def \
+                * self.def_energy_convert_to(self._energy_unit) \
+                * self.def_length_convert_to(self._length_unit)
+
+    def set_energy_unit(self, unit):
+        self.check_energy_unit(unit.lower())
+        if unit.lower() != self._energy_unit:
+            self._energy_unit = unit.lower()
+            self._energy = self._energy_dict[unit.lower()]
+
+            self._boltzmann = self._boltzmann_def * \
+                self.def_energy_convert_to(unit.lower())
+            self._coulomb = self._coulomb_def \
+                * self.def_energy_convert_to(self._energy_unit) \
+                * self.def_length_convert_to(self._length_unit)
+
+    def length(self, length, unit):
+        self.check_length_unit(unit.lower())
+        return length * self._length_dict[unit.lower()] / self._length
+
+    def energy(self, energy, unit):
+        self.check_energy_unit(unit.lower())
+        return energy * self._energy_dict[unit.lower()] / self._energy
+
+    def length_convert(self, unit_in, unit_out):
+        self.check_length_unit(unit_in.lower())
+        self.check_length_unit(unit_out.lower())
+        return self._length_dict[unit_in.lower()] / \
+            self._length_dict[unit_out.lower()]
+
+    def energy_convert(self, unit_in, unit_out):
+        self.check_energy_unit(unit_in.lower())
+        self.check_energy_unit(unit_out.lower())
+        return self._energy_dict[unit_in.lower()] / \
+            self._energy_dict[unit_out.lower()]
+
+    def length_convert_to(self, unit):
+        return self.length_convert(self._length_unit, unit.lower())
+
+    def energy_convert_to(self, unit):
+        return self.energy_convert(self._energy_unit, unit.lower())
+
+    def def_length_convert_to(self, unit):
+        return self.length_convert(self._length_unit_def, unit.lower())
+
+    def def_energy_convert_to(self, unit):
+        return self.energy_convert(self._energy_unit_def, unit.lower())
+
+    def length_convert_from(self, unit):
+        return self.length_convert(unit.lower(), self._length_unit)
+
+    def energy_convert_from(self, unit):
+        return self.energy_convert(unit.lower(), self._energy_unit)
+
+    def boltzmann_constant(self):
+        return self._boltzmann_constant
+
+    def boltzmann(self, unit=None):
+        if unit is None:
+            return self._boltzmann
+        return self._boltzmann_def * \
+            self.def_energy_convert_to(unit.lower())
+
+    def coulomb(self, energy_unit=None, length_unit=None):
+        if energy_unit is None and length_unit is None:
+            return self._coulomb
+
+        # fall back to the current units when only one unit is given,
+        # rather than failing on None.lower()
+        if energy_unit is None:
+            energy_unit = self._energy_unit
+        if length_unit is None:
+            length_unit = self._length_unit
+
+        scale_energy = self.def_energy_convert_to(energy_unit.lower())
+        scale_length = self.def_length_convert_to(length_unit.lower())
+        return self._coulomb_def * scale_energy * scale_length
+
+
+units = Units()
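+
+# A minimal usage sketch (illustrative only, not exercised by the package):
+# the module-level `units` instance converts values between recorded units.
+if __name__ == '__main__':
+    # 1 nm expressed in Angstrom (factor 10 by the tables above)
+    print(units.length_convert('nm', 'A'))
+    # 1 Hartree in kcal/mol (~627.5)
+    print(units.energy_convert('Ha', 'kcal/mol'))
+    # Boltzmann constant in the current energy unit (kJ/mol by default)
+    print(units.boltzmann())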