diff --git a/MindChemistry/applications/crystalflow/.gitignore b/MindChem/applications/crystalflow/.gitignore
similarity index 100%
rename from MindChemistry/applications/crystalflow/.gitignore
rename to MindChem/applications/crystalflow/.gitignore
diff --git a/MindChemistry/applications/diffcsp/README.md b/MindChem/applications/diffcsp/README.md
similarity index 100%
rename from MindChemistry/applications/diffcsp/README.md
rename to MindChem/applications/diffcsp/README.md
diff --git a/MindChemistry/applications/diffcsp/compute_metric.py b/MindChem/applications/diffcsp/compute_metric.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/compute_metric.py
rename to MindChem/applications/diffcsp/compute_metric.py
diff --git a/MindChemistry/applications/diffcsp/config.yaml b/MindChem/applications/diffcsp/config.yaml
similarity index 100%
rename from MindChemistry/applications/diffcsp/config.yaml
rename to MindChem/applications/diffcsp/config.yaml
diff --git a/MindChemistry/applications/diffcsp/data/crysloader.py b/MindChem/applications/diffcsp/data/crysloader.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/data/crysloader.py
rename to MindChem/applications/diffcsp/data/crysloader.py
diff --git a/MindChemistry/applications/diffcsp/data/data_utils.py b/MindChem/applications/diffcsp/data/data_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/data/data_utils.py
rename to MindChem/applications/diffcsp/data/data_utils.py
diff --git a/MindChemistry/applications/diffcsp/data/dataset.py b/MindChem/applications/diffcsp/data/dataset.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/data/dataset.py
rename to MindChem/applications/diffcsp/data/dataset.py
diff --git a/MindChemistry/applications/diffcsp/evaluate.py b/MindChem/applications/diffcsp/evaluate.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/evaluate.py
rename to MindChem/applications/diffcsp/evaluate.py
diff --git a/MindChemistry/applications/diffcsp/models/cspnet.py b/MindChem/applications/diffcsp/models/cspnet.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/cspnet.py
rename to MindChem/applications/diffcsp/models/cspnet.py
diff --git a/MindChemistry/applications/diffcsp/models/diff_utils.py b/MindChem/applications/diffcsp/models/diff_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/diff_utils.py
rename to MindChem/applications/diffcsp/models/diff_utils.py
diff --git a/MindChemistry/applications/diffcsp/models/diffusion.py b/MindChem/applications/diffcsp/models/diffusion.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/diffusion.py
rename to MindChem/applications/diffcsp/models/diffusion.py
diff --git a/MindChemistry/applications/diffcsp/models/infer_utils.py b/MindChem/applications/diffcsp/models/infer_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/infer_utils.py
rename to MindChem/applications/diffcsp/models/infer_utils.py
diff --git a/MindChemistry/applications/diffcsp/models/train_utils.py b/MindChem/applications/diffcsp/models/train_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/train_utils.py
rename to MindChem/applications/diffcsp/models/train_utils.py
diff --git a/MindChemistry/applications/diffcsp/requirement.txt b/MindChem/applications/diffcsp/requirement.txt
similarity index 100%
rename from MindChemistry/applications/diffcsp/requirement.txt
rename to MindChem/applications/diffcsp/requirement.txt
diff --git a/MindChemistry/applications/diffcsp/train.py b/MindChem/applications/diffcsp/train.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/train.py
rename to MindChem/applications/diffcsp/train.py
diff --git a/mindscience/e3nn/__init__.py b/mindscience/e3nn/__init__.py
index 69a14b29e1ced3fa627e5dada3f5f6ba239fdc1c..5ba0a5f6880df5a91a32252bd52fd919edda11a9 100644
--- a/mindscience/e3nn/__init__.py
+++ b/mindscience/e3nn/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""
-init
-"""
+"""init for e3 module"""
+from .o3 import *
+from .nn import *
+from .utils import *
 
-__all__ = []
\ No newline at end of file
+__all__ = []
+__all__.extend(o3.__all__)
+__all__.extend(nn.__all__)
+__all__.extend(utils.__all__)
diff --git a/mindscience/e3nn/nn/__init__.py b/mindscience/e3nn/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d292a9c21dd08499aa21e2e6f5c79007f1aa67
--- /dev/null
+++ b/mindscience/e3nn/nn/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init"""
+from .activation import Activation
+from .gate import Gate
+from .fc import FullyConnectedNet
+from .normact import NormActivation
+from .scatter import Scatter
+from .one_hot import SoftOneHotLinspace, soft_one_hot_linspace, soft_unit_step, OneHot
+from .batchnorm import BatchNorm
+
+__all__ = [
+    "Activation",
+    "Gate",
+    "FullyConnectedNet",
+    "NormActivation",
+    "Scatter",
+    "SoftOneHotLinspace",
+    "soft_one_hot_linspace",
+    "soft_unit_step",
+    "OneHot",
+    "BatchNorm"
+]
diff --git a/mindscience/e3nn/nn/activation.py b/mindscience/e3nn/nn/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8553fd84e3e5c1e1044f4dbe2b8f8deab0b42c64
--- /dev/null
+++ b/mindscience/e3nn/nn/activation.py
@@ -0,0 +1,147 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""activation"""
+import numpy as np
+
+from mindspore import Tensor, nn, ops, float32
+from ..o3.irreps import Irreps
+
+identity = ops.Identity()
+NTOL = 1e-5
+
+
+def _moment(f, n, dtype=float32):
+    x = Tensor(np.random.randn(1000000), dtype=dtype)
+    y = f(x).pow(n).mean().pow(-0.5)
+
+    return y
+
+
+def _parity_function(f, dtype=float32):
+    x = Tensor(np.linspace(.0, 10., 256), dtype=dtype)
+    y1, y2 = f(x).asnumpy(), f(-x).asnumpy()
+    if np.max(np.abs(y1 - y2)) < NTOL:
+        return 1
+    if np.max(np.abs(y1 + y2)) < NTOL:
+        return -1
+    return 0
+
+
+class _Normalize(nn.Cell):
+    """_Normalize"""
+
+    def __init__(self, f, dtype=float32):
+        super().__init__()
+        self.f = f
+        self.factor = _moment(f, 2, dtype)
+        if ops.abs(self.factor - 1.) < 1e-4:
+            self._is_id = True
+        else:
+            self._is_id = False
+
+    def construct(self, x):
+        if self._is_id:
+            return self.f(x)
+        return self.f(x).mul(self.factor)
+
+
+class Activation(nn.Cell):
+    r"""
+    Activation function for scalar-tensors. The parities of the output irreps may change according to the
+    parity of each activation function.
+    Odd input scalars require the corresponding activation functions to be either odd or even.
+
+    Args:
+        irreps_in (Union[str, Irrep, Irreps]): the input irreps.
+        acts (List[Func]): a list of activation functions for each part of `irreps_in`.
+            The length of `acts` will be clipped or filled by identity functions to match the length of `irreps_in`.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
+
+    Inputs:
+        - **inputs** (Tensor) - The shape of Tensor is :math:`(*, irreps\_in.dim)`.
+
+    Outputs:
+        - **outputs** (Tensor) - The shape of Tensor is :math:`(*, irreps\_in.dim)`.
+
+    Raises:
+        ValueError: If `irreps_in` contains a non-scalar irrep.
+        ValueError: If an irrep in `irreps_in` is odd, but the corresponding activation function is neither even nor odd.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.nn import Activation
+        >>> from mindspore import ops, Tensor
+        >>> act = Activation('3x0o+2x0e+1x0o', [ops.abs, ops.tanh])
+        >>> print(act)
+        Activation [xx-] (3x0o+2x0e+1x0o -> 3x0e+2x0e+1x0o)
+        >>> inputs = Tensor(ops.ones((4,6)))
+        >>> outputs = act(inputs)
+        >>> print(outputs.shape)
+        (4, 6)
+    """
+
+    def __init__(self, irreps_in, acts, dtype=float32):
+        super().__init__()
+        irreps_in = Irreps(irreps_in)
+        while len(acts) < len(irreps_in):
+            acts.append(None)
+        irreps_out = []
+        acts_out = []
+        for (mul, (l_in, p_in)), act in zip(irreps_in.data, acts):
+            if act is not None:
+                if l_in != 0:
+                    raise ValueError("Activation cannot apply an activation function to a non-scalar input.")
+
+                acts_out.append(_Normalize(act, dtype=dtype))
+                p_out = _parity_function(acts_out[-1]) if p_in == -1 else p_in
+
+                if p_out == 0:
+                    raise ValueError(
+                        "Parity mismatch: the input scalar is odd but the activation is neither even nor odd."
+                    )
+
+                irreps_out.append((mul, (0, p_out)))
+
+            else:
+                acts_out.append(identity)
+                irreps_out.append((mul, (l_in, p_in)))
+
+        self.irreps_in = irreps_in
+        self.irreps_out = Irreps(irreps_out)
+        self.acts = acts_out[:len(irreps_in)]
+
+    def construct(self, v):
+        """Implement the activation function for the input tensor."""
+        vs = self.irreps_in.decompose(v)
+        batch_shape = v.shape[:-1]
+        out_list = []
+        i = 0
+        for act in self.acts:
+            out_list.append(act(vs[i]).reshape(batch_shape + (self.irreps_in.data[i].dim,)))
+            i += 1
+
+        if len(out_list) > 1:
+            out = ops.concat(out_list, axis=-1)
+        elif len(out_list) == 1:
+            out = out_list[0]
+        else:
+            out = ops.zeros_like(v)
+        return out
+
+    def __repr__(self):
+        acts = "".join(["x" if a is not identity else "-" for a in self.acts])
+        return f"{self.__class__.__name__} [{acts}] ({self.irreps_in} -> {self.irreps_out})"
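
For reference, the parity probe in `_parity_function` above reduces to comparing `f(x)` with `f(-x)` on a sample grid. A minimal plain-numpy sketch of the same idea (the `check_parity` helper is a hypothetical name used only for this illustration, not part of the module):

import numpy as np

def check_parity(f, tol=1e-5):
    """Return 1 for an even function, -1 for an odd one, 0 otherwise."""
    x = np.linspace(0.0, 10.0, 256)
    y1, y2 = f(x), f(-x)
    if np.max(np.abs(y1 - y2)) < tol:
        return 1    # even: f(-x) == f(x)
    if np.max(np.abs(y1 + y2)) < tol:
        return -1   # odd: f(-x) == -f(x)
    return 0

assert check_parity(np.abs) == 1 and check_parity(np.tanh) == -1
assert check_parity(lambda x: x + 1.0) == 0   # no definite parity
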
diff --git a/mindscience/e3nn/nn/batchnorm.py b/mindscience/e3nn/nn/batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a98cb0ba28df4970c41a756e480f3fa230433a5
--- /dev/null
+++ b/mindscience/e3nn/nn/batchnorm.py
@@ -0,0 +1,181 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""batchnorm"""
+
+from mindspore import nn, Parameter, ops, float32
+
+from ..o3.irreps import Irreps
+
+
+class BatchNorm(nn.Cell):
+    r"""
+    Batch normalization for orthonormal representations.
+    It normalizes by the norm of the representations.
+    Note that the norm is invariant only for orthonormal representations.
+    Irreducible representations `wigner_D` are orthonormal.
+
+    Args:
+        irreps (Union[str, Irrep, Irreps]): the input irreps.
+        eps (float): small value added to the variance to avoid division by zero. Default: ``1e-5``.
+        momentum (float): momentum of the running average. Default: ``0.1``.
+        affine (bool): whether to use learnable weight and bias parameters. Default: ``True``.
+        reduce (str): {'mean', 'max'}, method used to reduce over the sample dimension. Default: ``'mean'``.
+        instance (bool): apply instance norm instead of batch norm. Default: ``False``.
+        normalization (str): {'component', 'norm'}, normalization method. Default: ``'component'``.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
+
+    Inputs:
+        - **input** (Tensor) - The shape of Tensor is :math:`(batch, ..., irreps.dim)`.
+
+    Outputs:
+        - **output** (Tensor) - The shape of Tensor is :math:`(batch, ..., irreps.dim)`.
+
+    Raises:
+        ValueError: If `reduce` is not in ['mean', 'max'].
+        ValueError: If `normalization` is not in ['component', 'norm'].
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.nn import BatchNorm
+        >>> from mindspore import ops, Tensor
+        >>> bn = BatchNorm('3x0o+2x0e+1x0o')
+        >>> print(bn)
+        BatchNorm (3x0o+2x0e+1x0o, eps=1e-05, momentum=0.1)
+        >>> inputs = Tensor(ops.ones((4, 6)))
+        >>> outputs = bn(inputs)
+        >>> print(outputs.shape)
+        (4, 6)
+    """
+
+    def __init__(self, irreps, eps=1e-5, momentum=0.1, affine=True, reduce='mean', instance=False,
+                 normalization='component', dtype=float32):
+        super().__init__()
+        self.irreps = Irreps(irreps)
+        self.eps = eps
+        self.momentum = momentum
+        self.affine = affine
+        self.instance = instance
+        self.reduce = reduce
+        self.normalization = normalization
+        self.training = True
+
+        num_scalar = sum(mul for mul, ir in self.irreps if ir.is_scalar())
+        num_features = self.irreps.num_irreps
+
+        self.running_mean = None if self.instance else Parameter(ops.zeros(num_scalar, dtype=dtype),
+                                                                 requires_grad=False)
+        self.running_var = None if self.instance else Parameter(ops.ones(num_features, dtype=dtype),
+                                                                requires_grad=False)
+
+        self.weight = Parameter(ops.ones(num_features, dtype=dtype)) if affine else None
+        self.bias = Parameter(ops.zeros(num_scalar, dtype=dtype)) if affine else None
+
+    def _roll_avg(self, curr, update):
+        return (1 - self.momentum) * curr + self.momentum * update
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} ({self.irreps}, eps={self.eps}, momentum={self.momentum})"
+
+    def construct(self, inputs):
+        """construct"""
+        inputs_shape = inputs.shape
+        batch = inputs_shape[0]
+        dim = inputs_shape[-1]
+        inputs = inputs.reshape(batch, -1, dim)
+
+        new_means = []
+        new_vars = []
+
+        fields = []
+        ix = 0
+        irm = 0
+        irv = 0
+        iw = 0
+        ib = 0
+
+        for mir in self.irreps.data:
+            mul = mir.mul
+            ir = mir.ir
+
+            d = ir.dim
+            field = inputs[:, :, ix: ix + mul * d]  # [batch, sample, mul * repr]
+            ix += mul * d
+
+            # (batch, sample, mul, repr)
+            field = field.reshape(batch, -1, mul, d)
+
+            if ir.is_scalar():  # scalars
+                if self.training or self.instance:
+                    if self.instance:
+                        field_mean = field.mean(1).reshape(batch, mul)  # [batch, mul]
+                    else:
+                        field_mean = field.mean([0, 1]).reshape(mul)  # [mul]
+                        new_means.append(
+                            self._roll_avg(self.running_mean[irm:irm + mul], field_mean)
+                        )
+                else:
+                    field_mean = self.running_mean[irm: irm + mul]
+                irm += mul
+
+                # (batch, sample, mul, repr)
+                field = field - field_mean.reshape(-1, 1, mul, 1)
+
+            if self.training or self.instance:
+                if self.normalization == 'norm':
+                    field_norm = field.pow(2).sum(3)  # [batch, sample, mul]
+                elif self.normalization == 'component':
+                    field_norm = field.pow(2).mean(3)  # [batch, sample, mul]
+                else:
+                    raise ValueError(f"Invalid normalization option {self.normalization}")
+
+                if self.reduce == 'mean':
+                    field_norm = field_norm.mean(1)  # [batch, mul]
+                elif self.reduce == 'max':
+                    field_norm = ops.amax(field_norm, 1)  # [batch, mul]
+                else:
+                    raise ValueError(f"Invalid reduce option {self.reduce}")
+
+                if not self.instance:
+                    field_norm = field_norm.mean(0)  # [mul]
+                    new_vars.append(self._roll_avg(self.running_var[irv: irv + mul], field_norm))
+            else:
+                field_norm = self.running_var[irv: irv + mul]
+            irv += mul
+
+            field_norm = (field_norm + self.eps).pow(-0.5)  # [(batch,) mul]
+
+            if self.affine:
+                weight = self.weight[iw: iw + mul]  # [mul]
+                iw += mul
+
+                field_norm = field_norm * weight  # [(batch,) mul]
+
+            field = field * field_norm.reshape(-1, 1, mul, 1)  # [batch, sample, mul, repr]
+
+            if self.affine and ir.is_scalar():  # scalars
+                bias = self.bias[ib: ib + mul]  # [mul]
+                ib += mul
+                field += bias.reshape(mul, 1)  # [batch, sample, mul, repr]
+
+            fields.append(field.reshape(batch, -1, mul * d))  # [batch, sample, mul * repr]
+
+        if self.training and not self.instance:
+            ops.assign(self.running_mean, ops.cat(new_means))
+            ops.assign(self.running_var, ops.cat(new_vars))
+
+        output = ops.cat(fields, 2)
+        return output.reshape(inputs_shape)
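
The two `normalization` conventions above differ only by a factor of the irrep dimension :math:`d = 2l + 1`; a small plain-numpy check of that relation (illustrative only, shapes as in the `construct` comments):

import numpy as np

field = np.random.randn(8, 5, 4, 3)         # [batch, sample, mul, repr], one 1e irrep (d = 3)
norm_sq = (field ** 2).sum(axis=3)          # 'norm': squared L2 norm per irrep
component = (field ** 2).mean(axis=3)       # 'component': mean squared component
assert np.allclose(norm_sq, 3 * component)  # they agree up to the factor d
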
diff --git a/mindscience/e3nn/nn/fc.py b/mindscience/e3nn/nn/fc.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d85dc4e2a4dba86f3caeffaaae6cf726c53ee21
--- /dev/null
+++ b/mindscience/e3nn/nn/fc.py
@@ -0,0 +1,110 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FullyConnectedNet"""
+from mindspore import Tensor, nn, Parameter, float32, ops
+from mindspore.common.initializer import initializer
+
+from .activation import _Normalize
+from ..utils.initializer import renormal_initializer
+
+identity = ops.Identity()
+
+
+class _Layer(nn.Cell):
+    r"""Single simple dense layer with parameter w."""
+
+    def __init__(self, h_in, h_out, act, init_method='normal', dtype=float32):
+        super().__init__()
+
+        init_method = renormal_initializer(init_method)
+
+        self.weight = Parameter(initializer(
+            init_method, (h_in, h_out), dtype), name='Layer')
+        self.act = act if act is not None else identity
+        self.h_in = h_in
+        self.h_out = h_out
+        self.weight_numel = self.weight.numel()
+        self.sqrt_h_in = ops.sqrt(Tensor(self.h_in, self.weight.dtype))
+
+    def construct(self, x):
+        w = self.weight / self.sqrt_h_in
+        x = ops.matmul(x, w)
+        x = self.act(x)
+        return x
+
+    def __repr__(self):
+        return f"Layer ({self.h_in}->{self.h_out})"
+
+
+class FullyConnectedNet(nn.SequentialCell):
+    r"""
+    Fully-connected neural network with normalized activation on scalars.
+
+    Args:
+        h_list (List[int]): a list of input, internal and output dimensions for the dense layers.
+        act (Func): activation function which will be automatically normalized. Default: ``None``.
+        out_act (bool): whether to apply the activation function on the output. Default: ``False``.
+        init_method (Union[str, mindspore.common.initializer]): initialization method for the weights.
+            Default: ``'normal'``.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
+
+    Inputs:
+        - **input** (Tensor) - The shape of Tensor is :math:`(h\_list[0])`.
+
+    Outputs:
+        - **output** (Tensor) - The shape of Tensor is :math:`(h\_list[-1])`.
+
+    Raises:
+        TypeError: If the elements of `h_list` are not `int`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import ops
+        >>> from mindchemistry.e3.nn import FullyConnectedNet
+        >>> fc = FullyConnectedNet([4,10,20,12,6], ops.tanh)
+        >>> print(fc)
+        FullyConnectedNet ([4, 10, 20, 12, 6] | 552 weights)
+        >>> v = ms.Tensor([.1,.2,.3,.4])
+        >>> grad = ops.grad(fc, weights=fc.trainable_params())
+        >>> fc(v).shape
+        (6,)
+        >>> [x.shape for x in grad(v)[1]]
+        [(4, 10), (10, 20), (20, 12), (12, 6)]
+
+    """
+
+    def __init__(self, h_list, act=None, out_act=False, init_method='normal', dtype=float32):
+        self.h_list = list(h_list)
+        if act is not None:
+            act = _Normalize(act, dtype=dtype)
+
+        self.layer_list = []
+
+        for i, (h1, h2) in enumerate(zip(self.h_list, self.h_list[1:])):
+            if not isinstance(h1, int) or not isinstance(h2, int):
+                raise TypeError
+
+            if i == len(self.h_list) - 2 and (not out_act):
+                a = identity
+            else:
+                a = act
+            layer = _Layer(h1, h2, a, init_method, dtype=dtype)
+            self.layer_list.append(layer)
+
+        super().__init__(self.layer_list)
+        self.weight_numel = sum([lay.weight_numel for lay in self.layer_list])
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} ({self.h_list} | {self.weight_numel} weights)"
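
The `1/sqrt(h_in)` factor in `_Layer` is the usual fan-in scaling: for inputs with unit-variance i.i.d. components, `x @ w / sqrt(h_in)` keeps the pre-activation variance near 1 independently of the layer width. A plain-numpy sketch of that claim (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
h_in, h_out = 1024, 4096
x = rng.standard_normal(h_in)
w = rng.standard_normal((h_in, h_out))
y = x @ w / np.sqrt(h_in)
print(y.var())  # ~= 1.0, regardless of h_in
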
diff --git a/mindscience/e3nn/nn/gate.py b/mindscience/e3nn/nn/gate.py
new file mode 100644
index 0000000000000000000000000000000000000000..f67a35d5557d3fe55d57ca37032c2b6e1cf0c9da
--- /dev/null
+++ b/mindscience/e3nn/nn/gate.py
@@ -0,0 +1,182 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""gate"""
+from mindspore import nn, ops, float32
+
+from .activation import Activation
+from ..o3.irreps import Irreps
+from ..o3.tensor_product import TensorProduct
+from ..utils.func import narrow
+
+
+class _Extract(nn.Cell):
+    """Extract a tuple of tensors from `irreps_in` into `irreps_outs`, following `instructions`."""
+
+    def __init__(self, irreps_in, irreps_outs, instructions):
+        super().__init__()
+        self.irreps_in = Irreps(irreps_in)
+        self.irreps_outs = tuple(Irreps(irreps) for irreps in irreps_outs)
+        self.instr = instructions
+
+        if not len(self.irreps_outs) == len(self.instr):
+            raise ValueError('inputs are illegal')
+        for irreps_out, ins in zip(self.irreps_outs, self.instr):
+            if not len(irreps_out) == len(ins):
+                raise ValueError('inputs are illegal')
+
+    def construct(self, x):
+        """construct"""
+        out = []
+        for i in range(len(self.irreps_outs)):
+            if self.instr[i] == tuple(range(len(self.irreps_in.data))):
+                out.append(x)
+            else:
+                out_i = []
+                for i_in in self.instr[i]:
+                    out_i.append(narrow(x, -1, *self.irreps_in.slice_tuples[i_in]))
+                if out_i:
+                    out.append(ops.concat(out_i, -1))
+        return out
+
+
+class _Sortcut(nn.Cell):
+    """Sort and cut a tensor by irreps_outs."""
+
+    def __init__(self, *irreps_outs):
+        super().__init__()
+        self.irreps_outs = tuple(Irreps(irreps).simplify() for irreps in irreps_outs)
+        irreps_in = sum(self.irreps_outs, Irreps([]))
+
+        i = 0
+        instructions = []
+        for irreps_out in self.irreps_outs:
+            instructions.append(tuple(range(i, i + len(irreps_out))))
+            i += len(irreps_out)
+
+        irreps_in, p, _ = irreps_in.sort()
+        instructions = [tuple(p[i] for i in x) for x in instructions]
+
+        self.cut = _Extract(irreps_in, self.irreps_outs, instructions)
+        self.irreps_in = irreps_in.simplify()
+
+    def construct(self, x):
+        return self.cut(x)
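+
+
+# Worked micro-example of the gating implemented below (illustrative values
+# only, ignoring the automatic normalization applied inside Activation): with
+# irreps_scalars='1x0e', acts=[tanh], irreps_gates='1x0e', act_gates=[sigmoid]
+# and irreps_gated='1x1o', an input [s, g, v1, v2, v3] maps approximately to
+# [tanh(s), sigmoid(g)*v1, sigmoid(g)*v2, sigmoid(g)*v3]: the plain scalar goes
+# through its activation, while the gate scalar is squashed and then multiplied
+# elementwise onto the gated 1o vector.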
+
+
+class Gate(nn.Cell):
+    r"""
+    Gate activation function. The input contains three parts: the first part `irreps_scalars` consists of
+    scalars that are only passed through the activation functions `acts`;
+    the second part `irreps_gates` consists of scalars that are passed through the activation functions
+    `act_gates` and then multiplied onto the third part, `irreps_gated`.
+
+    .. math::
+        \left(\bigoplus_i \phi_i(x_i) \right) \oplus \left(\bigoplus_j \phi_j(g_j) y_j \right)
+
+    where :math:`x_i` and :math:`\phi_i` are from `irreps_scalars` and `acts`, and :math:`g_j`, :math:`\phi_j`,
+    and :math:`y_j` are from `irreps_gates`, `act_gates`, and `irreps_gated`.
+
+    Args:
+        irreps_scalars (Union[str, Irrep, Irreps]): the input scalar irreps that will be passed through the
+            activation functions `acts`.
+        acts (List[Func]): a list of activation functions for each part of `irreps_scalars`.
+            The length of `acts` will be clipped or filled by identity functions to match the length of
+            `irreps_scalars`.
+        irreps_gates (Union[str, Irrep, Irreps]): the input scalar irreps that will be passed through the
+            activation functions `act_gates` and multiplied by `irreps_gated`.
+        act_gates (List[Func]): a list of activation functions for each part of `irreps_gates`.
+            The length of `act_gates` will be clipped or filled by identity functions to match the length of
+            `irreps_gates`.
+        irreps_gated (Union[str, Irrep, Irreps]): the input irreps that will be gated.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
+        ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module.
+            Default: ``mindspore.float32``.
+
+    Inputs:
+        - **input** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`.
+
+    Outputs:
+        - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_out.dim)`.
+
+    Raises:
+        ValueError: If `irreps_scalars` or `irreps_gates` contain a non-scalar irrep.
+        ValueError: If the total multiplicity of `irreps_gates` does not match the total multiplicity of
+            `irreps_gated`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindchemistry.e3.nn import Gate
+        >>> Gate('2x0e', [ops.tanh], '1x0o+2x0e', [ops.abs], '2x1o+1x2e')
+        Gate (2x0e+1x0o+2x0e+2x1o+1x2e -> 2x0e+2x1o+1x2e)
+    """
+
+    def __init__(self, irreps_scalars, acts, irreps_gates, act_gates, irreps_gated, dtype=float32, ncon_dtype=float32):
+        super().__init__()
+        irreps_scalars = Irreps(irreps_scalars)
+        irreps_gates = Irreps(irreps_gates)
+        irreps_gated = Irreps(irreps_gated)
+
+        # pylint: disable=C1801
+        if len(irreps_gates) > 0 and irreps_gates.lmax > 0:
+            raise ValueError(f"Gate scalars must be scalars, instead got irreps_gates = {irreps_gates}")
+        # pylint: disable=C1801
+        if len(irreps_scalars) > 0 and irreps_scalars.lmax > 0:
+            raise ValueError(f"Scalars must be scalars, instead got irreps_scalars = {irreps_scalars}")
+        if not irreps_gates.num_irreps == irreps_gated.num_irreps:
+            raise ValueError(f"There are {irreps_gated.num_irreps} irreps in irreps_gated, \
+                but a different number ({irreps_gates.num_irreps}) of gate scalars in irreps_gates")
+
+        self.sc = _Sortcut(irreps_scalars, irreps_gates, irreps_gated)
+        self.irreps_scalars, self.irreps_gates, self.irreps_gated = self.sc.irreps_outs
+
+        if self.irreps_scalars.num_irreps == 0:
+            self._has_scalar = False
+        else:
+            self._has_scalar = True
+            self.act_pass = Activation(irreps_scalars, acts, dtype=dtype)
+            irreps_scalars = self.act_pass.irreps_out
+        self.act_gates = Activation(irreps_gates, act_gates, dtype=dtype)
+        irreps_gates = self.act_gates.irreps_out
+
+        self.tp = TensorProduct(irreps_gated, irreps_gates, instructions='element', dtype=dtype, ncon_dtype=ncon_dtype)
+        irreps_gated = self.tp.irreps_out
+
+        self.irreps_in = self.sc.irreps_in
+        self.irreps_out = irreps_scalars + irreps_gated
+
+    def construct(self, x):
+        """Implement the gate activation function for the input tensor."""
+
+        scalars, gates, gated = self.sc(x)
+        if self._has_scalar:
+            scalars = self.act_pass(scalars)
+
+        if gates.shape[-1] > 0:
+            gates = self.act_gates(gates)
+            gated = self.tp(gated, gates)
+            if self._has_scalar:
+                x = ops.concat([scalars, gated], axis=-1)
+            else:
+                x = gated
+        else:
+            x = scalars
+
+        return x
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} ({self.irreps_in} -> {self.irreps_out})"
diff --git a/mindscience/e3nn/nn/normact.py b/mindscience/e3nn/nn/normact.py
new file mode 100644
index 0000000000000000000000000000000000000000..2080931fac8bec84cf578ea9f6eccee9af63abfc
--- /dev/null
+++ b/mindscience/e3nn/nn/normact.py
@@ -0,0 +1,128 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""normact"""
+from mindspore import nn, Parameter, float32, ops
+from mindspore.common.initializer import initializer
+
+from ..o3.irreps import Irreps
+from ..o3.tensor_product import TensorProduct
+from ..o3.norm import Norm
+
+
+class NormActivation(nn.Cell):
+    r"""Activation function for the norm of irreps.
+    Applies a scalar activation to the norm of each irrep and outputs a (normalized) version of that irrep multiplied
+    by the scalar output of the scalar activation.
+
+    Args:
+        irreps_in (Union[str, Irrep, Irreps]): the input irreps.
+        act (Func): an activation function for each part of the norm of `irreps_in`.
+        normalize (bool): whether to normalize the input features before multiplying them by the scalars from the
+            nonlinearity. Default: True.
+        epsilon (float): when ``normalize``, norms smaller than ``epsilon`` will be clamped up to ``epsilon``
+            to avoid division by zero. Not allowed when `normalize` is False. Default: None.
+        bias (bool): whether to apply a learnable additive bias to the inputs of the `act`. Default: False.
+        init_method (Union[str, float, mindspore.common.initializer]): initialize parameters.
+            Default: ``'normal'``.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
+        ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module.
+            Default: ``mindspore.float32``.
+
+    Inputs:
+        - **input** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`.
+
+    Outputs:
+        - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`.
+
+    Raises:
+        ValueError: If `epsilon` is not None and `normalize` is False.
+        ValueError: If `epsilon` is not positive.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.nn import NormActivation
+        >>> from mindspore import ops, Tensor
+        >>> norm_activation = NormActivation("2x1e", ops.sigmoid, bias=True)
+        >>> print(norm_activation)
+        NormActivation [sigmoid] (2x1e -> 2x1e)
+        >>> inputs = Tensor(ops.ones((4, 6)))
+        >>> outputs = norm_activation(inputs)
+        >>> print(outputs.shape)
+        (4, 6)
+    """
+
+    def __init__(self,
+                 irreps_in,
+                 act,
+                 normalize=True,
+                 epsilon=None,
+                 bias=False,
+                 init_method='zeros',
+                 dtype=float32,
+                 ncon_dtype=float32):
+        super().__init__()
+
+        self.irreps_in = Irreps(irreps_in)
+        self.irreps_out = Irreps(irreps_in)
+
+        if epsilon is None and normalize:
+            epsilon = 1e-8
+        elif epsilon is not None and not normalize:
+            raise ValueError("`epsilon` and `normalize = False` don't make sense together.")
+        elif epsilon is not None and not epsilon > 0:
+            raise ValueError(f"epsilon {epsilon} is invalid, must be strictly positive.")
+        self.epsilon = epsilon
+        if self.epsilon is not None:
+            self._eps_squared = epsilon * epsilon
+        else:
+            self._eps_squared = 0.0
+
+        self.norm = Norm(irreps_in, squared=(epsilon is not None), dtype=dtype)
+        self.act = act
+        self.normalize = normalize
+        if bias:
+            self.bias = Parameter(initializer(init_method, (self.irreps_in.num_irreps,), dtype),
+                                  name=self.__class__.__name__)
+        else:
+            self.bias = None
+
+        self.scalar_multiplier = TensorProduct(irreps_in1=self.norm.irreps_out,
+                                               irreps_in2=irreps_in,
+                                               instructions='element',
+                                               dtype=dtype,
+                                               ncon_dtype=ncon_dtype)
+
+    def construct(self, v):
+        """Implement the norm-activation function for the input tensor."""
+        norms = self.norm(v)
+        if self._eps_squared > 0:
+            norms[norms < self._eps_squared] = self._eps_squared
+            norms = ops.sqrt(norms)
+
+        nonlin_arg = norms
+        if self.bias is not None:
+            nonlin_arg = nonlin_arg + self.bias
+
+        scalings = self.act(nonlin_arg)
+        if self.normalize:
+            scalings = scalings / norms
+
+        return self.scalar_multiplier(scalings, v)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} [{self.act.__name__}] ({self.irreps_in} -> {self.irreps_out})"
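
The core of NormActivation in plain numpy, for a single irrep and ignoring the epsilon clamp and bias handled above (a minimal sketch, not the module's implementation): each vector is rescaled by `act(|v|)/|v|`, so only the rotation-invariant norm passes through the nonlinearity while the direction is preserved.

import numpy as np

def norm_act(v, act=lambda n: 1.0 / (1.0 + np.exp(-n))):  # sigmoid by default
    n = np.linalg.norm(v)
    return act(n) * v / n  # direction preserved, magnitude remapped

v = np.array([3.0, 4.0])   # |v| = 5
print(norm_act(v))         # ~= 0.9933 * [0.6, 0.8]
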
diff --git a/mindscience/e3nn/nn/one_hot.py b/mindscience/e3nn/nn/one_hot.py
new file mode 100644
index 0000000000000000000000000000000000000000..262b4863b78f5b1a4044b5ebf39e5f9f3ef7dfd3
--- /dev/null
+++ b/mindscience/e3nn/nn/one_hot.py
@@ -0,0 +1,238 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""OneHot"""
+import math
+
+import numpy as np
+
+from mindspore import Tensor, ops, nn, float32, float16
+from mindspore import numpy as mnp
+
+from ..o3.irreps import Irreps
+
+TMAP = {"MixedPrecisionType.FP16": float16, "MixedPrecisionType.FP32": float32}
+
+def soft_unit_step(x):
+    r"""
+    Smooth version of the unit step function.
+
+    .. math::
+        x \mapsto \theta(x) e^{-1/x}
+
+    Args:
+        x (Tensor): the input tensor.
+
+    Returns:
+        Tensor, the output of the unit step function.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.nn import soft_unit_step
+        >>> from mindspore import ops, set_context, Tensor
+        >>> x = Tensor(ops.linspace(-1.0, 10.0, 1000))
+        >>> outputs = soft_unit_step(x)
+        >>> print(outputs.shape)
+        (1000,)
+
+    """
+    return ops.relu(x) * ops.exp(- 1 / x) / x
+
+
+class OneHot(nn.Cell):
+    r"""
+    One-hot embedding of type indices.
+
+    Args:
+        num_types (int): the number of classes.
+        dtype (mindspore.dtype): The type of the output tensor. Default: ``mindspore.float32``.
+    """
+
+    def __init__(self, num_types, dtype=float32):
+        super().__init__()
+        self.num_types = num_types
+        self.irreps_output = Irreps([(self.num_types, (0, 1))])
+
+        self.one_hot = ops.OneHot()
+        self.on_off = (Tensor(1., dtype=dtype), Tensor(0., dtype=dtype))
+
+    def construct(self, atom_type):
+        type_numbers = atom_type
+        one_hot = self.one_hot(type_numbers, self.num_types, *self.on_off)
+        return one_hot
+
+    def __repr__(self):
+        return f'OneHot [num_types: {self.num_types}] ( -> {self.irreps_output})'
+
+
+# pylint: disable=C0103
+# pylint: disable=R1705
+class SoftOneHotLinspace(nn.Cell):
+    r"""
+    Projection on a basis of functions. Returns a set of :math:`\{y_i(x)\}_{i=1}^N`,
+
+    .. math::
+        y_i(x) = \frac{1}{Z} f_i(x)
+
+    where :math:`x` is the input and :math:`f_i` is the ith basis function.
+    :math:`Z` is a constant defined (if possible) such that,
+
+    .. math::
+        \langle \sum_{i=1}^N y_i(x)^2 \rangle_x \approx 1
+
+    Note that the `bessel` basis cannot be normalized.
+
+    Args:
+        start (float): minimum value spanned by the basis.
+        end (float): maximum value spanned by the basis.
+        number (int): number of basis functions :math:`N`.
+        basis (str): {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}, the basis family.
+            Default: ``'smooth_finite'``.
+        cutoff (bool): whether to require :math:`y_i(x)` to vanish outside the domain (`start`, `end`).
+            Default: ``True``.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
+
+    Inputs:
+        - **x** (Tensor) - The shape of Tensor is :math:`(...)`.
+
+    Outputs:
+        - **output** (Tensor) - The shape of Tensor is :math:`(..., N)`.
+
+    Raises:
+        ValueError: If `basis` is not in {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.nn import SoftOneHotLinspace
+        >>> from mindspore import ops, Tensor
+        >>> soft_one_hot_linspace = SoftOneHotLinspace(-0.5, 1.5, number=4)
+        >>> x = Tensor(ops.ones((4, 6)))
+        >>> outputs = soft_one_hot_linspace(x)
+        >>> print(outputs.shape)
+        (4, 6, 4)
+
+    """
+
+    def __init__(self, start, end, number, basis='smooth_finite', cutoff=True, dtype=float32):
+        super().__init__()
+
+        self.start = Tensor(start, dtype=dtype)
+        self.end = Tensor(end, dtype=dtype)
+        self.number = number
+        self.basis = basis
+        self.cutoff = cutoff
+
+        if self.cutoff:
+            self.values = Tensor(np.linspace(start, end, number), dtype=dtype)
+            self.step = self.values[1] - self.values[0]
+        else:
+            self.values = Tensor(np.linspace(start, end, number + 2), dtype=dtype)
+            self.step = self.values[1] - self.values[0]
+            self.values = self.values[1:-1]
+
+        self.PI = Tensor(math.pi, dtype=dtype)
+        self.c = self.end - self.start
+        self.consts = [
+            ops.exp(Tensor(2.0, dtype=dtype)),
+            ops.sqrt(Tensor(0.25 + self.number / 2, dtype=dtype)),
+            ops.sqrt(Tensor(2. / self.c, dtype=dtype))
+        ]
+        self.bessel_roots = mnp.arange(1, self.number + 1) * self.PI
+
+    def construct(self, x):
+        """construct"""
+        diff = (x.expand_dims(-1) - self.values) / self.step
+
+        if self.basis == 'gaussian':
+            return ops.exp(-diff.pow(2)) / 1.12
+
+        elif self.basis == 'cosine':
+            return ops.cos(self.PI / 2 * diff) * (diff < 1) * (-1 < diff)
+
+        elif self.basis == 'smooth_finite':
+            return 1.14136 * self.consts[0] * soft_unit_step(diff + 1.) * soft_unit_step(1. - diff)
+
+        elif self.basis == 'fourier':
+            x = (x.expand_dims(-1) - self.start) / (self.end - self.start)
+            if not self.cutoff:
+                i = mnp.arange(0, self.number)
+                return ops.cos(self.PI * i * x) / self.consts[1]
+            else:
+                i = mnp.arange(1, self.number + 1)
+                return ops.sin(self.PI * i * x) / self.consts[1] * (x > 0) * (x < 1)
+
+        if self.basis == 'bessel':
+            x = x.expand_dims(-1) - self.start
+            out = self.consts[2] * ops.sin(self.bessel_roots * x / self.c) / x
+
+            if not self.cutoff:
+                return out
+            else:
+                return out * ((x / self.c) < 1) * (x > 0)
+
+        else:
+            raise ValueError(f"Unsupported basis: {self.basis}.")
+
+    def _set_mixed_precision_type_recursive(self, dst_type):
+        super()._set_mixed_precision_type_recursive(dst_type)
+        self.values = self.values.astype(TMAP[dst_type.__str__()])
+        for i in range(len(self.consts)):
+            self.consts[i] = self.consts[i].astype(TMAP[dst_type.__str__()])
+
+
+def soft_one_hot_linspace(x, start, end, number, basis='smooth_finite', cutoff=True):
+    r"""
+    Projection on a basis of functions. Returns a set of :math:`\{y_i(x)\}_{i=1}^N`,
+
+    .. math::
+        y_i(x) = \frac{1}{Z} f_i(x)
+
+    where :math:`x` is the input and :math:`f_i` is the ith basis function.
+    :math:`Z` is a constant defined (if possible) such that,
+
+    .. math::
+        \langle \sum_{i=1}^N y_i(x)^2 \rangle_x \approx 1
+
+    Note that the `bessel` basis cannot be normalized.
+
+    Args:
+        x (Tensor): The shape of Tensor is :math:`(...)`.
+        start (float): minimum value spanned by the basis.
+        end (float): maximum value spanned by the basis.
+        number (int): number of basis functions :math:`N`.
+        basis (str): {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}, the basis family.
+            Default: ``'smooth_finite'``.
+        cutoff (bool): whether to require :math:`y_i(x)` to vanish outside the domain (`start`, `end`).
+            Default: ``True``.
+
+    Returns:
+        Tensor, shape is :math:`(..., N)`.
+
+    Raises:
+        ValueError: If `basis` is not in {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.nn import soft_one_hot_linspace
+        >>> from mindspore import ops, Tensor
+        >>> x = Tensor(ops.ones((4, 6)))
+        >>> outputs = soft_one_hot_linspace(x, -0.5, 1.5, number=4)
+        >>> print(outputs.shape)
+        (4, 6, 4)
+
+    """
+    soft = SoftOneHotLinspace(start, end, number, basis=basis, cutoff=cutoff, dtype=x.dtype)
+    return soft(x)
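
The normalization claim :math:`\langle \sum_i y_i(x)^2 \rangle_x \approx 1` can be checked numerically for the 'gaussian' branch with a plain-numpy re-implementation (illustrative only; the 1.12 constant is the normalizer used above):

import numpy as np

start, end, number = -0.5, 1.5, 10
values = np.linspace(start, end, number)
step = values[1] - values[0]
x = np.random.uniform(start, end, 100000)[:, None]
y = np.exp(-(((x - values) / step) ** 2)) / 1.12
print((y ** 2).sum(axis=1).mean())  # close to 1 (slightly below, from edge effects)
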
diff --git a/mindscience/e3nn/nn/scatter.py b/mindscience/e3nn/nn/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..922ac15ef17bd3e417ad6b190e7518c7b40dbe96
--- /dev/null
+++ b/mindscience/e3nn/nn/scatter.py
@@ -0,0 +1,74 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""scatter"""
+from mindspore import ops, nn
+from mindspore.ops import operations as P
+
+
+class Scatter(nn.Cell):
+    r"""
+    Easy-to-use version of scatter.
+
+    Args:
+        mode (str): {'add', 'sum', 'div', 'max', 'min', 'mul'}, scatter mode.
+
+    Raises:
+        ValueError: If `mode` is not one of the supported modes.
+
+    Supported Platforms:
+        ``CPU`` ``GPU`` ``Ascend``
+
+    """
+
+    def __init__(self, mode='add'):
+        super().__init__()
+        self.mode = mode
+        if mode in ('add', 'sum'):
+            self.scatter = P.TensorScatterAdd()
+        elif mode == 'div':
+            self.scatter = P.TensorScatterDiv()
+        elif mode == 'max':
+            self.scatter = P.TensorScatterMax()
+        elif mode == 'min':
+            self.scatter = P.TensorScatterMin()
+        elif mode == 'mul':
+            self.scatter = P.TensorScatterMul()
+        else:
+            raise ValueError(f"Unexpected scatter mode {mode}")
+
+        self.zeros = ops.Zeros()
+
+    def construct(self, src, index, out=None, dim_size=None):
+        r"""
+        Args:
+            src (Tensor): The source tensor.
+            index (Tensor): The indices of elements to scatter.
+            out (Tensor): The destination tensor. Default: None.
+            dim_size (int): If `out` is not given, automatically create output with size `dim_size`.
+                If `dim_size` is not given, a minimal sized output tensor is returned. Default: None.
+
+        Returns:
+            Tensor.
+        """
+        if index.ndim < 2:
+            index = index.unsqueeze(-1)
+        if out is not None:
+            return self.scatter(out, index, src)
+        dim_size = src.shape[0] if dim_size is None else dim_size
+        zero = self.zeros((dim_size, src.shape[1]), src.dtype)
+        return self.scatter(zero, index, src)
+
+    def __repr__(self):
+        return f'Scatter [{self.mode}]'
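
Reference semantics of `Scatter('add')` above, in plain numpy (a sketch of the behavior, not the MindSpore implementation): rows of `src` are accumulated into the output at the positions given by `index`.

import numpy as np

src = np.array([[1., 1.], [2., 2.], [3., 3.]])
index = np.array([0, 1, 0])
out = np.zeros((2, 2))
np.add.at(out, index, src)  # out[index[i]] += src[i]
print(out)                  # [[4. 4.], [2. 2.]]
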
diff --git a/mindscience/e3nn/o3/__init__.py b/mindscience/e3nn/o3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44b27fe6d8d7cec2757486b8930a93cc2de6a77c
--- /dev/null
+++ b/mindscience/e3nn/o3/__init__.py
@@ -0,0 +1,51 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init"""
+from .irreps import Irrep, Irreps
+from .rotation import *
+from .wigner import change_basis_real_to_complex, su2_generators, so3_generators, wigner_D, wigner_3j
+from .spherical_harmonics import SphericalHarmonics, spherical_harmonics
+from .tensor_product import TensorProduct
+from .sub import *
+from .norm import Norm
+
+__all__ = [
+    "Irrep",
+    "Irreps",
+    "identity_angles",
+    "rand_angles",
+    "compose_angles",
+    "matrix_x",
+    "matrix_y",
+    "matrix_z",
+    "angles_to_matrix",
+    "matrix_to_angles",
+    "angles_to_xyz",
+    "xyz_to_angles",
+    "change_basis_real_to_complex",
+    "su2_generators",
+    "so3_generators",
+    "wigner_D",
+    "wigner_3j",
+    "TensorProduct",
+    "SphericalHarmonics",
+    "spherical_harmonics",
+    "FullyConnectedTensorProduct",
+    "FullTensorProduct",
+    "ElementwiseTensorProduct",
+    "Linear",
+    "TensorSquare",
+    "Norm",
+]
diff --git a/mindscience/e3nn/o3/irreps.py b/mindscience/e3nn/o3/irreps.py
new file mode 100644
index 0000000000000000000000000000000000000000..01273bf9d2a1df5428118cedbb5f1a6b9d6996a0
--- /dev/null
+++ b/mindscience/e3nn/o3/irreps.py
@@ -0,0 +1,761 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import itertools
+import collections
+import dataclasses
+
+import numpy as np
+
+from mindspore import jit_class, Tensor, ops
+
+from .wigner import wigner_D
+from .rotation import matrix_to_angles
+from ..utils.func import broadcast_args, _to_tensor, norm_keep, _expand_last_dims, narrow
+from ..utils.perm import _inverse
+from ..utils.linalg import _direct_sum
+
+# pylint: disable=C0111
+
+@jit_class
+@dataclasses.dataclass(init=False, frozen=True)
+class Irrep:
+    r"""
+    Irreducible representation of O(3). This class does not contain any data; it is a structure that
+    describes the representation.
+    It is typically used as argument of other classes of the library to define the input and output
+    representations of functions.
+
+    Args:
+        l (Union[int, str]): the degree of the representation, a non-negative integer
+            :math:`l = 0, 1, \dots`, or a string encoding both the degree and the parity.
+        p (int): {1, -1}, the parity of the representation. Default: ``None``.
+
+    Raises:
+        NotImplementedError: If method is not implemented.
+        ValueError: If `l` is negative or `p` is not in {1, -1}.
+        ValueError: If `l` cannot be converted to an `Irrep`.
+        TypeError: If `l` is not int or str.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import Irrep
+        >>> Irrep(0, 1)
+        0e
+        >>> Irrep("1y")
+        1o
+        >>> Irrep("2o").dim
+        5
+        >>> Irrep("2e") in Irrep("1o") * Irrep("1o")
+        True
+        >>> Irrep("1o") + Irrep("2o")
+        1x1o+1x2o
+    """
+    l: int
+    p: int
+
+    def __init__(self, l, p=None):
+        if p is None:
+            if isinstance(l, Irrep):
+                p = l.p
+                l = l.l
+
+            if isinstance(l, _MulIr):
+                p = l.ir.p
+                l = l.ir.l
+
+            if isinstance(l, str):
+                try:
+                    name = l.strip()
+                    l = int(name[:-1])
+                    if l < 0:
+                        raise ValueError
+                    p = {
+                        'e': 1,
+                        'o': -1,
+                        'y': (-1) ** l,
+                    }[name[-1]]
+                except Exception:
+                    raise ValueError
+            elif isinstance(l, tuple):
+                l, p = l
+
+        if not isinstance(l, int):
+            raise TypeError
+        elif l < 0:
+            raise ValueError
+        if p not in [-1, 1]:
+            raise ValueError
+        object.__setattr__(self, "l", l)
+        object.__setattr__(self, "p", p)
+
+    def __repr__(self):
+        """Representation of the Irrep."""
+        p = {+1: 'e', -1: 'o'}[self.p]
+        return f"{self.l}{p}"
+
+    @classmethod
+    def iterator(cls, lmax=None):
+        """Yield all irreps in order of increasing degree, both parities per degree, up to `lmax`."""
+        for l in itertools.count():
+            yield Irrep(l, (-1) ** l)
+            yield Irrep(l, -(-1) ** l)
+
+            if l == lmax:
+                break
+
+    def wigD_from_angles(self, alpha, beta, gamma, k=None):
+        r"""
+        Representation wigner D matrices of O(3) from Euler angles.
+
+        Args:
+            alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third.
+            beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second.
+            gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first.
+            k (Union[None, Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): How many times the parity is applied. Default: ``None`` .
+
+        Returns:
+            Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)` .
+
+        Examples:
+            >>> m = Irrep(1, -1).wigD_from_angles(0, 0 ,0, 1)
+            >>> print(m)
+            [[-1, 0, 0],
+             [ 0, -1, 0],
+             [ 0, 0, -1]]
+        """
+        if k is None:
+            k = ops.zeros_like(_to_tensor(alpha))
+
+        alpha, beta, gamma, k = broadcast_args(alpha, beta, gamma, k)
+        return wigner_D(self.l, alpha, beta, gamma) * self.p ** _expand_last_dims(k)
+
+    def wigD_from_matrix(self, R):
+        r"""
+        Representation wigner D matrices of O(3) from rotation matrices.
+
+        Args:
+            R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`.
+
+        Returns:
+            Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)`.
+
+        Raises:
+            TypeError: If `R` is not a Tensor.
+
+        Examples:
+            >>> from mindspore import ops
+            >>> m = Irrep(1, -1).wigD_from_matrix(-ops.eye(3))
+            >>> print(m)
+            [[-1, 0, 0],
+             [ 0, -1, 0],
+             [ 0, 0, -1]]
+        """
+        if not isinstance(R, Tensor):
+            raise TypeError
+        d = Tensor(np.sign(np.linalg.det(R.asnumpy())))
+        R = _expand_last_dims(d) * R
+        k = (1. - d) / 2
+        return self.wigD_from_angles(*matrix_to_angles(R), k)
+
+    @property
+    def dim(self) -> int:
+        return 2 * self.l + 1
+
+    def is_scalar(self) -> bool:
+        return self.l == 0 and self.p == 1
+
+    def __mul__(self, other):
+        r"""
+        Generate the irreps from the product of two irreps.
+
+        Returns:
+            generator of `Irrep`.
+        """
+        other = Irrep(other)
+        p = self.p * other.p
+        lmin = abs(self.l - other.l)
+        lmax = self.l + other.l
+        for l in range(lmin, lmax + 1):
+            yield Irrep(l, p)
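+
+    # Selection rules, illustrated: the parities multiply and the degree runs
+    # from |l1 - l2| to l1 + l2, so the product of two vector irreps gives
+    #     [str(ir) for ir in Irrep("1o") * Irrep("1o")]  ->  ['0e', '1e', '2e']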
+
+    def __rmul__(self, other):
+        r"""
+        Return an `Irreps` containing `other` copies of this `Irrep`.
+
+        Args:
+            other (int): the multiplicity to assign.
+
+        Returns:
+            `Irreps` - the `Irrep` repeated with multiplicity `other`.
+
+        Raises:
+            TypeError: If `other` is not int.
+        """
+        if not isinstance(other, int):
+            raise TypeError
+        return Irreps([(other, self)])
+
+    def __add__(self, other):
+        r"""Sum of two irreps."""
+        return Irreps(self) + Irreps(other)
+
+    def __radd__(self, other):
+        r"""Sum of two irreps."""
+        return Irreps(other) + Irreps(self)
+
+    def __iter__(self):
+        r"""Deconstruct the irrep into ``l`` and ``p``."""
+        yield self.l
+        yield self.p
+
+    def __lt__(self, other):
+        r"""Compare the order of two irreps."""
+        return (self.l, self.p) < (other.l, other.p)
+
+    def __eq__(self, other):
+        """Compare two irreps."""
+        other = Irrep(other)
+        return (self.l, self.p) == (other.l, other.p)
+
+
+@jit_class
+@dataclasses.dataclass(init=False, frozen=True)
+class _MulIr:
+    """Multiple Irrep."""
+    mul: int
+    ir: Irrep
+
+    def __init__(self, mul, ir=None):
+        if ir is None:
+            mul, ir = mul
+
+        if not (isinstance(mul, int) and isinstance(ir, Irrep)):
+            raise TypeError
+        object.__setattr__(self, "mul", mul)
+        object.__setattr__(self, "ir", ir)
+
+    @property
+    def dim(self):
+        return self.mul * self.ir.dim
+
+    def __repr__(self):
+        """Representation of the irrep."""
+        return f"{self.mul}x{self.ir}"
+
+    def __iter__(self):
+        """Deconstruct the mulirrep into `mul` and `ir`."""
+        yield self.mul
+        yield self.ir
+
+    def __lt__(self, other):
+        """Compare the order of two mulirreps."""
+        return (self.ir, self.mul) < (other.ir, other.mul)
+
+    def __eq__(self, other):
+        """Compare two irreps."""
+        return (self.mul, self.ir) == (other.mul, other.ir)
+
+
+@jit_class
+@dataclasses.dataclass(init=False, frozen=False)
+class Irreps:
+    r"""
+    Direct sum of irreducible representations of O(3). This class does not contain any data; it is a structure
+    that describes the representation.
+    It is typically used as argument of other classes of the library to define the input and output
+    representations of functions.
+
+    Args:
+        irreps (Union[str, Irrep, Irreps, List[Tuple[int]]]): a string to represent the direct sum of
+            irreducible representations.
+
+    Raises:
+        ValueError: If `irreps` cannot be converted to an `Irreps`.
+        ValueError: If the multiplicity part of `irreps` is negative.
+        TypeError: If the multiplicity part of `irreps` is not int.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import Irreps
+        >>> x = Irreps([(100, (0, 1)), (50, (1, 1))])
+        >>> x
+        100x0e+50x1e
+        >>> x.dim
+        250
+        >>> Irreps("100x0e+50x1e+0x2e")
+        100x0e+50x1e+0x2e
+        >>> Irreps("100x0e+50x1e+0x2e").lmax
+        1
+        >>> Irrep("2e") in Irreps("0e+2e")
+        True
+        >>> Irreps(), Irreps("")
+        (, )
+        >>> Irreps('2x1o+1x0o') * Irreps('2x1o+1x0e')
+        4x0e+1x0o+2x1o+4x1e+2x1e+4x2e
+    """
+    __slots__ = ('data', 'dim', 'slice', 'slice_tuples')
+
+    def __init__(self, irreps=None):
+        if isinstance(irreps, Irreps):
+            self.data = irreps.data
+            self.dim = irreps.dim
+            self.slice = irreps.slice
+            self.slice_tuples = irreps.slice_tuples
+        else:
+            out = ()
+            if isinstance(irreps, Irrep):
+                out += (_MulIr(1, Irrep(irreps)),)
+            elif isinstance(irreps, _MulIr):
+                out += (irreps,)
+            elif isinstance(irreps, str):
+                try:
+                    if irreps.strip() != "":
+                        for mir in irreps.split('+'):
+                            if 'x' in mir:
+                                mul, ir = mir.split('x')
+                                mul = int(mul)
+                                ir = Irrep(ir)
+                            else:
+                                mul = 1
+                                ir = Irrep(mir)
+
+                            if not isinstance(mul, int):
+                                raise TypeError
+                            elif mul < 0:
+                                raise ValueError
+                            out += (_MulIr(mul, ir),)
+                except Exception:
+                    raise ValueError
+            elif irreps is None:
+                pass
+            else:
+                out = self.handle_irreps(irreps, out)
+            self.data = out
+            self.dim = self._dim()
+            self.slice = self._slices()
+            self.slice_tuples = [(s.start, s.stop - s.start) for s in self.slice]
+
+    def handle_irreps(self, irreps, out):
+        """Normalize an iterable of irreps specifications into a tuple of `_MulIr`."""
+        for mir in irreps:
+
+            if isinstance(mir, str):
+                if 'x' in mir:
+                    mul, ir = mir.split('x')
+                    mul = int(mul)
+                    ir = Irrep(ir)
+                else:
+                    mul = 1
+                    ir = Irrep(mir)
+            elif isinstance(mir, Irrep):
+                mul = 1
+                ir = mir
+            elif isinstance(mir, _MulIr):
+                mul, ir = mir
+            elif isinstance(mir, int):
+                mul, ir = 1, Irrep(l=mir, p=1)
+            elif len(mir) == 2:
+                mul, ir = mir
+                ir = Irrep(ir)
+
+            if not (isinstance(mul, int) and mul >= 0 and ir is not None):
+                raise ValueError
+
+            out += (_MulIr(mul, ir),)
+        return out
+
+    def __iter__(self):
+        return iter(self.data)
+
+    def __hash__(self):
+        return hash(self.data)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __repr__(self):
+        """Representation of the irreps."""
+        return "+".join(f"{mir}" for mir in self.data)
+
+    def __eq__(self, other):
+        """Compare two irreps."""
+        other = Irreps(other)
+        if not len(self) == len(other):
+            return False
+        for m_1, m_2 in zip(self.data, other.data):
+            if not m_1 == m_2:
+                return False
+        return True
+
+    def __contains__(self, ir):
+        """Check if an irrep or an irreps is in the representation."""
+        try:
+            ir = Irrep(ir)
+            return ir in (irrep for _, irrep in self.data)
+        except:
+            irreps = Irreps(ir)
+            m, n = len(irreps), len(self)
+            mask = [False] * n
+
+            def dfs(i):
+                if i == m:
+                    return True
+                for j in range(n):
+                    if not mask[j]:
+                        if irreps.data[i].mul <= self.data[j].mul and irreps.data[i].ir == self.data[j].ir:
+                            mask[j] = True
+                            found = dfs(i + 1)
+                            if found:
+                                return True
+                            mask[j] = False
+                return False
+
+            return dfs(0)
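+
+    # Containment is multiset-style: each (mul, ir) entry of the left operand
+    # must fit into a distinct entry of self with the same ir and at least the
+    # same multiplicity, e.g.
+    #     Irreps("1x0e+1x1o") in Irreps("2x0e+1x1o")  ->  True
+    #     Irreps("3x0e") in Irreps("2x0e+1x1o")       ->  False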
+ """ + if isinstance(other, Irreps): + res = Irreps() + for mir_1 in self.data: + for mir_2 in other.data: + out_ir = mir_1.ir * mir_2.ir + for ir in out_ir: + res += mir_1.mul * mir_2.mul * ir + res, p, _ = res.simplify().sort() + return res + return Irreps([(mul * other, ir) for mul, ir in self.data]) + + def __rmul__(self, other): + r""" + Return repeated `Irreps` of multiple `Irreps`. + + Args: + other (int): multiple number of the `Irreps`. + + Returns: + `Irreps` - repeated multiple `Irreps`. + """ + return self * other + + def _dim(self): + """The dimension of the representation, :math:`2 l + 1`.""" + return sum(mul * ir.dim for mul, ir in self.data) + + @property + def num_irreps(self): + return sum(mul for mul, _ in self.data) + + @property + def ls(self): + res = [] + for mul, (l, _) in self.data: + res.extend([l] * mul) + return res + + @property + def lmax(self): + if len(self) == 0: + raise ValueError("Cannot get lmax of empty Irreps") + return max(self.ls) + + def count(self, ir): + r""" + Multiplicity of `ir`. + + Args: + ir (Irrep): `Irrep` + + Returns: + int, total multiplicity of `ir`. + + Examples: + >>> Irreps("1o + 3x2e").count("2e") + 3 + """ + ir = Irrep(ir) + res = 0 + for mul, irrep in self.data: + if ir == irrep: + res += mul + return res + + def simplify(self): + """ + Simplify the representations. + + Returns: + `Irreps` + + Examples: + >>> Irreps("1e + 1e + 0e").simplify() + 2x1e+1x0e + >>> Irreps("1e + 1e + 0e + 1e").simplify() + 2x1e+1x0e+1x1e + """ + out = [] + for mul, ir in self.data: + if out and out[-1][1] == ir: + out[-1] = (out[-1][0] + mul, ir) + elif mul > 0: + out.append((mul, ir)) + return Irreps(out) + + def remove_zero_multiplicities(self): + """ + Remove any irreps with multiplicities of zero. + + Returns: + `Irreps` + + Examples: + >>> Irreps("4x0e + 0x1o + 2x3e").remove_zero_multiplicities() + 4x0e+2x3e + """ + out = [(mul, ir) for mul, ir in self.data if mul > 0] + return Irreps(out) + + def _slices(self): + r""" + List of slices corresponding to indices for each irrep. + + Examples: + >>> Irreps('2x0e + 1e').slices() + [slice(0, 2, None), slice(2, 5, None)] + """ + s = [] + i = 0 + for mir in self.data: + s.append(slice(i, i + mir.dim)) + i += mir.dim + return s + + def sort(self): + r""" + Sort the representations by increasing degree. + + Returns: + irreps (`Irreps`) - sorted `Irreps` + + p (tuple[int]) - permute orders. `p[old_index] = new_index` + + inv (tuple[int]) - inversed permute orders. `p[new_index] = old_index` + + Examples: + >>> Irreps("1e + 0e + 1e").sort().irreps + 1x0e+1x1e+1x1e + >>> Irreps("2o + 1e + 0e + 1e").sort().p + (3, 1, 0, 2) + >>> Irreps("2o + 1e + 0e + 1e").sort().inv + (2, 1, 3, 0) + """ + Ret = collections.namedtuple("sort", ["irreps", "p", "inv"]) + out = [(ir, i, mul) for i, (mul, ir) in enumerate(self.data)] + out = sorted(out) + inv = tuple(i for _, i, _ in out) + p = _inverse(inv) + irreps = Irreps([(mul, ir) for ir, _, mul in out]) + return Ret(irreps, p, inv) + + def filter(self, keep=None, drop=None): + r""" + Filter the `Irreps` by either `keep` or `drop`. + + Args: + keep (Union[str, Irrep, Irreps, List[str, Irrep]]): list of irrep to keep. Default: None. + drop (Union[str, Irrep, Irreps, List[str, Irrep]]): list of irrep to drop. Default: None. + + Returns: + `Irreps`, filtered irreps. + + Raises: + ValueError: If both `keep` and `drop` are not `None`. 
+ + Examples: + >>> Irreps("1o + 2e").filter(keep="1o") + 1x1o + >>> Irreps("1o + 2e").filter(drop="1o") + 1x2e + """ + if keep is None and drop is None: + return self + if keep is not None and drop is not None: + raise ValueError("Cannot specify both keep and drop") + if keep is not None: + keep = Irreps(keep).data + keep = {mir.ir for mir in keep} + return Irreps([(mul, ir) for mul, ir in self.data if ir in keep]) + if drop is not None: + drop = Irreps(drop).data + drop = {mir.ir for mir in drop} + return Irreps([(mul, ir) for mul, ir in self.data if not ir in drop]) + return None + + def decompose(self, v, batch=False): + r""" + Decompose a vector by `Irreps`. + + Args: + v (Tensor): the vector to be decomposed. + batch (bool): whether reshape the result such that there is at least a batch dimension. Default: `False`. + + Returns: + List of Tensors, the decomposed vectors by `Irreps`. + + Raises: + TypeError: If v is not Tensor. + ValueError: If length of the vector `v` is not matching with dimension of `Irreps`. + + Examples: + >>> import mindspore as ms + >>> input = ms.Tensor([1, 2, 3]) + >>> m = Irreps("1o").decompose(input) + >>> print(m) + [Tensor(shape=[1,3], dtype=Int64, value= + [[1,2,3]])] + """ + if not isinstance(v, Tensor): + raise TypeError( + f"The input for decompose should be Tensor, but got {type(v)}.") + len_v = v.shape[-1] + if not self.dim == len_v: + raise ValueError( + f"the shape of input {v.shape[-1]} do not match irreps dimension {self.dim}.") + + res = [] + batch_shape = v.shape[:-1] + for (s, l), mir in zip(self.slice_tuples, self.data): + v_slice = narrow(v, -1, s, l) + if v.ndim == 1 and batch: + res.append(v_slice.reshape( + (1,) + batch_shape + (mir.mul, mir.ir.dim))) + else: + res.append(v_slice.reshape( + batch_shape + (mir.mul, mir.ir.dim))) + + return res + + @staticmethod + def spherical_harmonics(lmax, p=-1): + r""" + Representation of the spherical harmonics. + + Args: + lmax (int): maximum of `l`. + p (int): {1, -1}, the parity of the representation. + + Returns: + `Irreps`, representation of :math:`(Y^0, Y^1, \dots, Y^{\mathrm{lmax}})`. + + Examples: + >>> Irreps.spherical_harmonics(3) + 1x0e+1x1o+1x2e+1x3o + >>> Irreps.spherical_harmonics(4, p=1) + 1x0e+1x1e+1x2e+1x3e+1x4e + """ + return Irreps([(1, (l, p ** l)) for l in range(lmax + 1)]) + + def randn(self, *size, normalization='component'): + r""" + Random tensor. + + Args: + *size (List[int]): size of the output tensor, needs to contains a `-1`. + normalization (str): {'component', 'norm'}, type of normalization method. + + Returns: + Tensor, the shape is `size` where `-1` is replaced by `self.dim`. + + Examples: + >>> Irreps("5x0e + 10x1o").randn(5, -1, 5, normalization='norm').shape + (5, 35, 5) + """ + di = size.index(-1) + lsize = size[:di] + rsize = size[di + 1:] + + if normalization == 'component': + return ops.standard_normal((*lsize, self.dim, *rsize)) + elif normalization == 'norm': + x_list = [] + for s, (mul, ir) in zip(self.slice, self.data): + if mul < 1: + continue + r = ops.standard_normal((*lsize, mul, ir.dim, *rsize)) + r = r / norm_keep(r, axis=di + 1) + + x_list.append(r.reshape((*lsize, -1, *rsize))) + return ops.concat(x_list, axis=di) + else: + raise ValueError("Normalization needs to be 'norm' or 'component'") + + def wigD_from_angles(self, alpha, beta, gamma, k=None): + r""" + Representation wigner D matrices of O(3) from Euler angles. 
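+
+        The matrix is assembled, as in the implementation below, as the block-diagonal
+        direct sum of the Wigner D matrices of each irrep, one block per multiplicity:
+
+        .. math::
+            D(\alpha, \beta, \gamma) = \bigoplus_{(\text{mul}, \, l)} \bigoplus_{k=1}^{\text{mul}} D^l(\alpha, \beta, \gamma)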
+ + Args: + alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third. + beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second. + gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first. + k (Union[None, Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): How many times the parity is applied. Default: None. + + Returns: + Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)` + + Examples: + >>> m = Irreps("1o").wigD_from_angles(0, 0 ,0, 1) + >>> print(m) + [[-1, 0, 0], + [ 0, -1, 0], + [ 0, 0, -1]] + """ + return _direct_sum(*[ir.wigD_from_angles(alpha, beta, gamma, k) for mul, ir in self for _ in range(mul)]) + + def wigD_from_matrix(self, R): + r""" + Representation wigner D matrices of O(3) from rotation matrices. + + Args: + R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`. + + Returns: + Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)` + + Raises: + TypeError: If `R` is not a Tensor. + + Examples: + >>> m = Irreps("1o").wigD_from_matrix(-ops.eye(3)) + >>> print(m) + [[-1, 0, 0], + [ 0, -1, 0], + [ 0, 0, -1]] + """ + if not isinstance(R, Tensor): + raise TypeError + d = Tensor(np.sign(np.linalg.det(R.asnumpy()))) + R = _expand_last_dims(d) * R + k = (1 - d) / 2 + return self.wigD_from_angles(*matrix_to_angles(R), k) diff --git a/mindscience/e3nn/o3/norm.py b/mindscience/e3nn/o3/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..150e52178bd619c1f3428c7175384cc00662526a --- /dev/null +++ b/mindscience/e3nn/o3/norm.py @@ -0,0 +1,81 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""norm""" +from mindspore import nn, ops, float32 + +from .irreps import Irreps +from .tensor_product import TensorProduct + + +class Norm(nn.Cell): + r""" + Norm of each irrep in a direct sum of irreps. + + Args: + irreps_in (Union[str, Irrep, Irreps]): Irreps for the input. + squared (bool): whether to return the squared norm. Default: False. + dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` . + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` . + + Inputs: + - **v** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)` . + + Outputs: + - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_out.dim)` . 
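+
+    Note:
+        A sketch of the computation, up to the 'component' normalization used in the
+        underlying tensor product: for each irrep block :math:`v_i` of the input,
+
+        .. math::
+            \text{out}_i = \sqrt{\text{ReLU}\Big(\sum_m v_{i,m} \, v_{i,m}\Big)},
+
+        and the square root is skipped when `squared=True`.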
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> import mindspore as ms + >>> import numpy as np + >>> from mindchemistry.e3.o3 import Norm + >>> n = Norm('3x1o') + >>> v = ms.Tensor(np.linspace(1., 2., n.irreps_in.dim), dtype=ms.float32) + >>> n(v).shape + (1, 3) + + """ + + def __init__(self, irreps_in, squared=False, dtype=float32, ncon_dtype=float32): + super().__init__() + + self.squared = squared + irreps_in = Irreps(irreps_in).simplify() + irreps_out = Irreps([(mul, "0e") for mul, _ in irreps_in]) + + instr = [(i, i, i, "uuu", False, ir.dim) for i, (mul, ir) in enumerate(irreps_in)] + + self.tp = TensorProduct(irreps_in, + irreps_in, + irreps_out, + instr, + irrep_norm="component", + dtype=dtype, + ncon_dtype=ncon_dtype) + + self.irreps_in = irreps_in + self.irreps_out = irreps_out.simplify() + + def construct(self, v): + """Implement the norm-activation function for the input tensor.""" + out = self.tp(v, v) + if self.squared: + return out + return ops.sqrt(ops.relu(out)) + + def __repr__(self): + return f"{self.__class__.__name__} ({self.irreps_in})" diff --git a/mindscience/e3nn/o3/rotation.py b/mindscience/e3nn/o3/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..96bbe21cc4be755124cf97043f55733d6158a11f --- /dev/null +++ b/mindscience/e3nn/o3/rotation.py @@ -0,0 +1,387 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""rotation""" +import math +import random + +import numpy as np + +from mindspore import Tensor, float32, ops + +from ..utils.func import broadcast_args, _to_tensor, norm_keep + +seed = int(random.random() * 10000) +zeros = ops.Zeros() +cos = ops.Cos() +sin = ops.Sin() +rand = ops.UniformReal(seed=seed) + + +def identity_angles(*shape, dtype=float32): + r""" + Give the identity set of Euler angles. + + Args: + shape (Tuple[int]): The shape of additional dimensions. + dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` . + + Returns: + alpha (Tensor) - The alpha Euler angles. + + beta (Tensor) - The beta Euler angles. + + gamma (Tensor) - The gamma Euler angles. + + Raises: + TypeError: If dtype of 'shape' is not tuple. + TypeError: If dtype of the element of 'shape' is not int. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import identity_angles + >>> m = identity_angles((1)) + >>> print(m) + (Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]), Tensor(shape=[1], dtype=Float32, + value= [ 0.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00])) + """ + if not isinstance(shape, tuple): + raise TypeError + if not all(map(lambda x: isinstance(x, int), shape)): + raise TypeError + abc = zeros((3,) + shape, dtype) + return abc[0], abc[1], abc[2] + + +def rand_angles(*shape): + r""" + Give a random set of Euler angles. + + Args: + shape (Tuple[int]): The shape of additional dimensions. 
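+
+    Note:
+        `alpha` and `gamma` are drawn uniformly from :math:`[0, 2\pi)`, while `beta` is
+        drawn as :math:`\arccos(2u - 1)` with :math:`u \sim U(0, 1)`, so the resulting
+        rotations are uniformly distributed over the rotation group.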
+
+    Returns:
+        alpha (Tensor) - The alpha Euler angles.
+
+        beta (Tensor) - The beta Euler angles.
+
+        gamma (Tensor) - The gamma Euler angles.
+
+    Raises:
+        TypeError: If dtype of 'shape' is not tuple.
+        TypeError: If dtype of the element of 'shape' is not int.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import rand_angles
+        >>> m = rand_angles((1))
+        >>> print(m)
+        (Tensor(shape=[1], dtype=Float32, value= [ 4.00494671e+00]), Tensor(shape=[1], dtype=Float32,
+        value= [ 1.29240000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 5.71690750e+00]))
+    """
+    if not isinstance(shape, tuple):
+        raise TypeError
+    if not all(map(lambda x: isinstance(x, int), shape)):
+        raise TypeError
+    alpha, gamma = 2 * math.pi * rand((2,) + shape)
+    beta = ops.acos(2 * rand(shape) - 1)
+    return alpha, beta, gamma
+
+
+def compose_angles(a1, b1, c1, a2, b2, c2):
+    r"""
+    Compute the composed Euler angles of two sets of Euler angles.
+
+    .. math::
+
+        R(a, b, c) = R(a_1, b_1, c_1) \circ R(a_2, b_2, c_2)
+
+    Note:
+        The second set of Euler angles 'a2, b2, c2' is applied first, while the first set
+        of Euler angles 'a1, b1, c1' is applied second.
+        The elements of Euler angles should be one of the following types: float, float32, np.float32.
+
+    Args:
+        a1 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The second applied alpha Euler angles.
+        b1 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The second applied beta Euler angles.
+        c1 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The second applied gamma Euler angles.
+        a2 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The first applied alpha Euler angles.
+        b2 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The first applied beta Euler angles.
+        c2 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The first applied gamma Euler angles.
+
+    Returns:
+        - alpha (Tensor), The composed alpha Euler angles.
+        - beta (Tensor), The composed beta Euler angles.
+        - gamma (Tensor), The composed gamma Euler angles.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import compose_angles
+        >>> m = compose_angles(0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
+        >>> print(m)
+        (Tensor(shape=[], dtype=Float32, value= 1.34227), Tensor(shape=[], dtype=Float32, value= 1.02462),
+        Tensor(shape=[], dtype=Float32, value= 1.47115))
+    """
+
+    a1, b1, c1, a2, b2, c2 = broadcast_args(a1, b1, c1, a2, b2, c2)
+    return matrix_to_angles(
+        ops.matmul(angles_to_matrix(a1, b1, c1), angles_to_matrix(a2, b2, c2)))
+
+
+def matrix_x(angle):
+    r"""
+    Give the rotation matrices around x axis for given angle.
+
+    Args:
+        angle (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The rotation angles around x axis.
+            The shape of 'angle' is :math:`(...)`.
+
+    Returns:
+        Tensor, the rotation matrices around x axis. The shape of output is :math:`(..., 3, 3)`
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import matrix_x
+        >>> m = matrix_x(0.4)
+        >>> print(m)
+        [[ 1.          0.          0.        ]
+         [ 0.          0.92106086 -0.38941833]
+         [ 0.
0.38941833 0.92106086]] + """ + angle = _to_tensor(angle) + o = ops.ones_like(angle) + z = ops.zeros_like(angle) + return ops.stack([ + ops.stack([o, z, z], axis=-1), + ops.stack([z, cos(angle), -sin(angle)], axis=-1), + ops.stack([z, sin(angle), cos(angle)], axis=-1), + ], + axis=-2) + + +def matrix_y(angle): + r""" + Give the rotation matrices around y axis for given angle. + + Args: + angle (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The rotation angles around y axis. + + Returns: + Tensor, the rotation matrices around y axis. The shape of output is :math:`(..., 3, 3)` + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import matrix_y + >>> m = matrix_y(0.5) + >>> print(m) + [[ 0.87758255 0. 0.47942555] + [ 0. 1. 0. ] + [-0.47942555 0. 0.87758255]] + """ + angle = _to_tensor(angle) + o = ops.ones_like(angle) + z = ops.zeros_like(angle) + return ops.stack([ + ops.stack([cos(angle), z, sin(angle)], axis=-1), + ops.stack([z, o, z], axis=-1), + ops.stack([-sin(angle), z, cos(angle)], axis=-1), + ], + axis=-2) + + +def matrix_z(angle): + r""" + Give the rotation matrices around z axis for given angle. + + Args: + angle (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The rotation angles around z axis. + The shape of 'angle' is :math:`(...)`. + + Returns: + Tensor, the rotation matrices around z axis. The shape of output is :math:`(..., 3, 3)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import matrix_z + >>> m = matrix_z(0.6) + >>> print(m) + [[ 0.8253357 -0.5646425 0. ] + [ 0.5646425 0.8253357 0. ] + [ 0. 0. 1. ]] + """ + angle = _to_tensor(angle) + o = ops.ones_like(angle) + z = ops.zeros_like(angle) + return ops.stack([ + ops.stack([cos(angle), -sin(angle), z], axis=-1), + ops.stack([sin(angle), cos(angle), z], axis=-1), + ops.stack([z, z, o], axis=-1), + ], + axis=-2) + + +def angles_to_matrix(alpha, beta, gamma): + r""" + Conversion from angles to matrix. + + Args: + alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The alpha Euler angles. The shape of Tensor is :math:`(...)`. + beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The beta Euler angles. The shape of Tensor is :math:`(...)`. + gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The gamma Euler angles. The shape of Tensor is :math:`(...)`. + + Returns: + Tensor, the rotation matrices. Matrices of shape :math:`(..., 3, 3)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import angles_to_matrix + >>> m = angles_to_matrix(0.4, 0.5, 0.6) + >>> print(m) + [[ 0.5672197 0.1866971 0.8021259 ] + [ 0.27070403 0.87758255 -0.395687 ] + [-0.77780527 0.44158012 0.4472424 ]] + """ + alpha, beta, gamma = broadcast_args(alpha, beta, gamma) + return ops.matmul(ops.matmul(matrix_y(alpha), matrix_x(beta)), + matrix_y(gamma)) + + +def matrix_to_angles(r_param): + r""" + Conversion from matrix to angles. + + Args: + r_param (Tensor): The rotation matrices. Matrices of shape :math:`(..., 3, 3)`. + + Returns: + - alpha (Tensor), The alpha Euler angles. The shape of Tensor is :math:`(...)`. + - beta (Tensor), The beta Euler angles. The shape of Tensor is :math:`(...)`. + - gamma (Tensor), The gamma Euler angles. The shape of Tensor is :math:`(...)`. + + Raise: + ValueError: If the det(R) is not equal to 1. 
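+
+    Note:
+        The angles are recovered constructively: :math:`(\alpha, \beta)` are read off from
+        the image of :math:`\vec e_y` under the rotation, and :math:`\gamma` from the
+        residual in-plane rotation via `atan2`.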
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> import mindspore as ms + >>> from mindchemistry.e3.o3 import matrix_to_angles + >>> input = ms.Tensor([[0.5672197, 0.1866971, 0.8021259], [0.27070403, 0.87758255, -0.395687], + ... [-0.77780527, 0.44158012,0.4472424]]) + >>> m = matrix_to_angles(input) + >>> print(m) + (Tensor(shape=[], dtype=Float32, value= 0.4), Tensor(shape=[], dtype=Float32, value= 0.5), + Tensor(shape=[], dtype=Float32, value= 0.6)) + """ + if not np.allclose(np.linalg.det(r_param.asnumpy()), 1., 1e-3, 1e-5): + raise ValueError + + x = ops.matmul(r_param, Tensor([0.0, 1.0, 0.0])) + a, b = xyz_to_angles(x) + tmp_r_param = angles_to_matrix(a, b, ops.zeros_like(a)) + perm = tuple(range(len(tmp_r_param.shape))) + r_param = ops.matmul( + tmp_r_param.transpose(perm[:-2] + (perm[-1],) + (perm[-2],)), + r_param) + c = ops.atan2(r_param[..., 0, 2], r_param[..., 0, 0]) + return a, b, c + + +def angles_to_xyz(alpha, beta): + r""" + Convert :math:`(\alpha, \beta)` into a point :math:`(x, y, z)` on the sphere. + + Args: + alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The alpha Euler angles. The shape of Tensor is :math:`(...)`. + beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The beta Euler angles. The shape of Tensor is :math:`(...)`. + + Returns: + Tensor, the point :math:`(x, y, z)` on the sphere. The shape of Tensor is :math:`(..., 3)` + + Supported Platforms: + ``Ascend`` + + Examples + >>> import mindspore as ms + >>> from mindchemistry.e3.o3 import angles_to_xyz + >>> print(angles_to_xyz(ms.Tensor(1.7), ms.Tensor(0.0)).abs()) + [0., 1., 0.] + """ + alpha, beta = broadcast_args(alpha, beta) + x = sin(beta) * sin(alpha) + y = cos(beta) + z = sin(beta) * cos(alpha) + return ops.stack([x, y, z], axis=-1) + + +def xyz_to_angles(xyz): + r""" + Convert a point :math:`\vec r = (x, y, z)` on the sphere into angles :math:`(\alpha, \beta)`. + + .. math:: + \vec r = R(\alpha, \beta, 0) \vec e_z + + Args: + xyz (Tensor): The point :math:`(x, y, z)` on the sphere. The shape of Tensor is :math:`(..., 3)`. + + Returns: + alpha (Tensor) - The alpha Euler angles. The shape of Tensor is :math:`(...)`. + beta (Tensor) - The beta Euler angles. The shape of Tensor is :math:`(...)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import mindspore as ms + >>> from mindchemistry.e3.o3 import xyz_to_angles + >>> input = ms.Tensor([3, 3, 3]) + >>> m = xyz_to_angles(input) + >>> print(m) + (Tensor(shape=[], dtype=Float32, value= 0.785398), Tensor(shape=[], dtype=Float32, value= 0.955318)) + """ + xyz = xyz / norm_keep(xyz, axis=-1) + xyz = ops.nan_to_num(ops.clamp(xyz, -1, 1), 1.0) + + beta = ops.acos(xyz[..., 1]) + alpha = ops.atan2(xyz[..., 0], xyz[..., 2]) + return alpha, beta diff --git a/mindscience/e3nn/o3/spherical_harmonics.py b/mindscience/e3nn/o3/spherical_harmonics.py new file mode 100644 index 0000000000000000000000000000000000000000..f6725802beded7af8d0139a8d8cf3de355fff6d8 --- /dev/null +++ b/mindscience/e3nn/o3/spherical_harmonics.py @@ -0,0 +1,725 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""SphericalHarmonics"""
+from mindspore import Tensor, nn, ops, float32
+from .irreps import Irreps
+
+
+def _sqrt(x, dtype=float32):
+    sqrt = ops.Sqrt()
+    return sqrt(Tensor(x, dtype=dtype))
+
+
+class SphericalHarmonics(nn.Cell):
+    r"""
+    Spherical harmonics layer.
+
+    Args:
+        irreps_out (Union[str, `Irreps`]): irreducible representations of output for spherical harmonics.
+        normalize (bool): whether to normalize the input Tensor to unit vectors that lie on the sphere before
+            projecting onto the spherical harmonics.
+        normalization (str): {'integral', 'component', 'norm'}, normalization method of the output tensors.
+            Default: ``'integral'``.
+        irreps_in (Union[str, `Irreps`, None]): irreducible representations of input for spherical harmonics.
+            Default: ``None``.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` .
+
+    Inputs:
+        - **x** (Tensor) - Tensor from which to construct the spherical harmonics. The shape of Tensor is :math:`(..., 3)`.
+
+    Outputs:
+        - **output** (Tensor) - the spherical harmonics :math:`Y^l(x)`. The shape of Tensor is :math:`(..., 2l+1)`.
+
+    Raises:
+        ValueError: If `normalization` is not in {'integral', 'component', 'norm'}.
+        ValueError: If `irreps_in` for SphericalHarmonics is neither a vector (`1x1o`) nor a pseudovector (`1x1e`).
+        ValueError: If the `l` and `p` of `irreps_out` are not consistent with `irreps_in` for spherical harmonics.
+            The output parity should have been p = {input_p**l}.
+        NotImplementedError: If `l` is larger than 11.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import SphericalHarmonics
+        >>> from mindspore import ops
+        >>> sh = SphericalHarmonics(0, False, normalization='component')
+        >>> x = ops.rand(2,3)
+        >>> m = sh(x)
+        >>> print(m)
+        [[1.]
+         [1.]]
+    """
+
+    def __init__(self, irreps_out, normalize, normalization='integral', irreps_in=None, dtype=float32):
+        super().__init__()
+        self.normalize = normalize
+        self.normalization = normalization
+        if normalization not in ['integral', 'component', 'norm']:
+            raise ValueError
+
+        if isinstance(irreps_out, str):
+            irreps_out = Irreps(irreps_out)
+        if isinstance(irreps_out, Irreps) and irreps_in is None:
+            for mul, (l, p) in irreps_out:
+                if l % 2 == 1 and p == 1:
+                    irreps_in = Irreps("1e")
+        if irreps_in is None:
+            irreps_in = Irreps("1o")
+
+        irreps_in = Irreps(irreps_in)
+        if irreps_in not in (Irreps("1x1o"), Irreps("1x1e")):
+            raise ValueError
+        self.irreps_in = irreps_in
+        input_p = irreps_in.data[0].ir.p
+
+        if isinstance(irreps_out, Irreps):
+            ls = []
+            for mul, (l, p) in irreps_out:
+                if p != input_p ** l:
+                    raise ValueError
+                ls.extend([l] * mul)
+        elif isinstance(irreps_out, int):
+            ls = [irreps_out]
+        else:
+            ls = list(irreps_out)
+
+        irreps_out = Irreps([(1, (l, input_p ** l)) for l in ls]).simplify()
+        self.irreps_out = irreps_out
+        self._ls_list = ls
+        self._lmax = max(ls)
+        self._is_range_lmax = ls == list(range(max(ls) + 1))
+        self._prof_str = f'spherical_harmonics({ls})'
+        self.ones = ops.Ones()
+
+        if self.normalization == 'integral':
+            self.norm_factors = [
+                (_sqrt(2 * l + 1., dtype) / 3.5449077018110318) *
+                self.ones(2 * l + 1, dtype)
+                for l in self._ls_list
+            ]
+        elif self.normalization == 'component':
+            self.norm_factors = [
+                _sqrt(2 * l + 1., dtype) * self.ones(2 * l + 1, dtype)
+                for l in self._ls_list
+            ]
+
+        self.l2_normalize = ops.L2Normalize(axis=-1, epsilon=0.000000000001)
+
+    def construct(self, x):
+        """
+        Compute spherical harmonics of vector `x`.
+
+        Args:
+            x (Tensor): Tensor from which to construct the spherical harmonics, of shape ``(..., 3)``.
+
+        Returns:
+            Tensor, the spherical harmonics :math:`Y^l(x)`. The shape of Tensor is ``(..., 2l+1)``
+
+        Examples:
+            >>> sh = SphericalHarmonics(irreps_out="1o + 2x2e", normalize=True)
+            >>> input = ops.ones([1,3])
+            >>> output = sh(input)
+            >>> print(output)
+            [[0.28209478 0.28209478 0.28209478 0.36418277 0.36418277 0
+              0.36418277 0 0.36418277 0.36418277 0 0.36418277
+              0]]
+        """
+        last_dim = x.shape[-1]
+        if last_dim != 3:
+            raise ValueError
+
+        if self.normalize:
+            x = self.l2_normalize(x)
+
+        sh = _spherical_harmonics(self._lmax, x[..., 0], x[..., 1], x[..., 2])
+
+        if not self._is_range_lmax:
+            sh = ops.concat([
+                sh[..., l * l:(l + 1) * (l + 1)]
+                for l in self._ls_list
+            ], axis=-1)
+        if self.normalization != 'norm':
+            sh = ops.mul(sh, ops.concat(self.norm_factors))
+
+        return sh
+
+    def __repr__(self):
+        return f'SphericalHarmonics {self._ls_list} ({self.irreps_in} -> {self.irreps_out})'
+
+
+def spherical_harmonics(l, x, normalize=True, normalization='integral'):
+    r"""
+    Compute spherical harmonics.
+
+    Spherical harmonics are polynomials defined on the 3d space,
+    :math:`Y^l: \mathbb{R}^3 \longrightarrow \mathbb{R}^{2l+1}`,
+    usually restricted on the sphere (with ``normalize=True``),
+    :math:`Y^l: S^2 \longrightarrow \mathbb{R}^{2l+1}`,
+    which satisfy the following properties:
+
+    - they are polynomials of the cartesian coordinates ``x, y, z``
+    - they are equivariant: :math:`Y^l(R x) = D^l(R) Y^l(x)`
+    - they are orthogonal: :math:`\int_{S^2} Y^l_m(x) Y^j_n(x) dx = \text{cste} \; \delta_{lj} \delta_{mn}`
+
+    The value of the constant depends on the choice of normalization.
+
+    They obey the following recursion, where :math:`C` are the `wigner_3j`:
+
+    .. math::
+
+        Y^{l+1}_i(x) &= \text{cste}(l) \; C_{ijk} Y^l_j(x) x_k \\
+        \partial_k Y^{l+1}_i(x) &= \text{cste}(l) \; (l+1) \; C_{ijk} Y^l_j(x)
+
+    Args:
+        l (Union[int, List[int]]): degree of the spherical harmonics.
+        x (Tensor): tensor from which to construct the spherical harmonics, of shape ``(..., 3)``.
+        normalize (bool): whether to normalize the ``x`` to unit vectors that lie on the sphere before projecting onto
+            the spherical harmonics.
+        normalization (str): {'integral', 'component', 'norm'}, normalization method of the output tensors.
+            Default: 'integral'.
+            'component': :math:`\|Y^l(x)\|^2 = 2l+1, x \in S^2`
+            'norm': :math:`\|Y^l(x)\| = 1, x \in S^2`, ``component / sqrt(2l+1)``
+            'integral': :math:`\int_{S^2} Y^l_m(x)^2 dx = 1`, ``component / sqrt(4pi)``
+
+    Returns:
+        Tensor, the spherical harmonics :math:`Y^l(x)`. The shape of Tensor is ``(..., 2l+1)``.
+
+    Raises:
+        ValueError: If `normalization` is not in {'integral', 'component', 'norm'}.
+        ValueError: If `irreps_in` for SphericalHarmonics is neither a vector (`1x1o`) nor a pseudovector (`1x1e`).
+        ValueError: If the `l` and `p` of `irreps_out` are not consistent with `irreps_in` for spherical harmonics.
+            The output parity should have been p = {input_p**l}.
+        ValueError: If the tensor `x` does not have the shape ``(..., 3)``.
+        NotImplementedError: If `l` is larger than 11.
+
+    """
+    sh = SphericalHarmonics(l, normalize, normalization, dtype=x.dtype)
+    return sh(x)
+
+def _sh0(x):
+    """
+    Compute spherical harmonics of degree 0.
+
+    Args:
+        x (Tensor): the x-component of the input points, of shape ``(...)``.
+
+    Returns:
+        List of Tensor, the spherical harmonics :math:`Y^0(x)`: one tensor of shape ``(...)``.
+    """
+    sh_0_0 = ops.ones_like(x)
+    return [sh_0_0]
+
+def _sh1(x, y, z):
+    """
+    Compute spherical harmonics of degree 1.
+
+    Args:
+        x (Tensor): the x-component of the input points, of shape ``(...)``.
+        y (Tensor): the y-component of the input points, of shape ``(...)``.
+        z (Tensor): the z-component of the input points, of shape ``(...)``.
+
+    Returns:
+        List of Tensor, the spherical harmonics :math:`Y^1(x)`: three tensors of shape ``(...)``.
+    """
+    sh_1_0 = x
+    sh_1_1 = y
+    sh_1_2 = z
+    return [sh_1_0, sh_1_1, sh_1_2]
+
+def _sh2(x, y, z):
+    """
+    Compute spherical harmonics of degree 2.
+
+    Args:
+        x (Tensor): the x-component of the input points, of shape ``(...)``.
+        y (Tensor): the y-component of the input points, of shape ``(...)``.
+        z (Tensor): the z-component of the input points, of shape ``(...)``.
+
+    Returns:
+        List of Tensor, the spherical harmonics :math:`Y^2(x)`: five tensors of shape ``(...)``.
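+
+    Note:
+        The raw values returned by the `_sh*` helpers follow the 'norm' convention
+        (unit norm on the unit sphere); `SphericalHarmonics.construct` rescales them
+        for the 'integral' and 'component' conventions.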
+ """ + sh_2_0 = 1.7320508075688772 * x * z + sh_2_1 = 1.7320508075688772 * x * y + y2 = y.pow(2) + x2z2 = x.pow(2) + z.pow(2) + sh_2_2 = y2 - 0.5 * x2z2 + sh_2_3 = 1.7320508075688772 * y * z + sh_2_4 = 1.7320508075688772 / 2.0 * (z.pow(2) - x.pow(2)) + return [sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4] + +def _sh3(x, y, z, prev): + """Compute spherical harmonics of degree 3.""" + sh_2_0, sh_2_4 = prev[0], prev[4] + y2 = y.pow(2) + x2z2 = x.pow(2) + z.pow(2) + sh_3_0 = 0.9128709291752769 * (sh_2_0 * z + sh_2_4 * x) + sh_3_1 = 2.23606797749979 * sh_2_0 * y + sh_3_2 = 0.6123724356957945 * (4.0 * y2 - x2z2) * x + sh_3_3 = 0.5 * y * (2.0 * y2 - 3.0 * x2z2) + sh_3_4 = 0.6123724356957945 * z * (4.0 * y2 - x2z2) + sh_3_5 = 2.23606797749979 * sh_2_4 * y + sh_3_6 = 0.9128709291752769 * (sh_2_4 * z - sh_2_0 * x) + return [sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6] + +def _sh4(x, y, z, prev): + """Compute spherical harmonics of degree 4.""" + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6 = prev + sh_4_0 = 0.935414346693485 * sh_3_0 * z + 0.935414346693485 * sh_3_6 * x + sh_4_1 = 0.661437827766148 * sh_3_0 * y + 0.810092587300982 * \ + sh_3_1 * z + 0.810092587300983 * sh_3_5 * x + sh_4_2 = -0.176776695296637 * sh_3_0 * z + 0.866025403784439 * sh_3_1 * y + \ + 0.684653196881458 * sh_3_2 * z + 0.684653196881457 * \ + sh_3_4 * x + 0.176776695296637 * sh_3_6 * x + sh_4_3 = -0.306186217847897 * sh_3_1 * z + 0.968245836551855 * sh_3_2 * \ + y + 0.790569415042095 * sh_3_3 * x + 0.306186217847897 * sh_3_5 * x + sh_4_4 = -0.612372435695795 * sh_3_2 * x + \ + sh_3_3 * y - 0.612372435695795 * sh_3_4 * z + sh_4_5 = -0.306186217847897 * sh_3_1 * x + 0.790569415042096 * sh_3_3 * \ + z + 0.968245836551854 * sh_3_4 * y - 0.306186217847897 * sh_3_5 * z + sh_4_6 = -0.176776695296637 * sh_3_0 * x - 0.684653196881457 * sh_3_2 * x + \ + 0.684653196881457 * sh_3_4 * z + 0.866025403784439 * \ + sh_3_5 * y - 0.176776695296637 * sh_3_6 * z + sh_4_7 = -0.810092587300982 * sh_3_1 * x + 0.810092587300982 * \ + sh_3_5 * z + 0.661437827766148 * sh_3_6 * y + sh_4_8 = -0.935414346693485 * sh_3_0 * x + 0.935414346693486 * sh_3_6 * z + return [sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8] + +def _sh5(x, y, z, prev): + """Compute spherical harmonics of degree 5.""" + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8 = prev + sh_5_0 = 0.948683298050513 * sh_4_0 * z + 0.948683298050513 * sh_4_8 * x + sh_5_1 = 0.6 * sh_4_0 * y + 0.848528137423857 * \ + sh_4_1 * z + 0.848528137423858 * sh_4_7 * x + sh_5_2 = -0.14142135623731 * sh_4_0 * z + 0.8 * sh_4_1 * y + 0.748331477354788 * \ + sh_4_2 * z + 0.748331477354788 * sh_4_6 * x + 0.14142135623731 * sh_4_8 * x + sh_5_3 = -0.244948974278318 * sh_4_1 * z + 0.916515138991168 * sh_4_2 * y + \ + 0.648074069840786 * sh_4_3 * z + 0.648074069840787 * \ + sh_4_5 * x + 0.244948974278318 * sh_4_7 * x + sh_5_4 = -0.346410161513776 * sh_4_2 * z + 0.979795897113272 * sh_4_3 * \ + y + 0.774596669241484 * sh_4_4 * x + 0.346410161513776 * sh_4_6 * x + sh_5_5 = -0.632455532033676 * sh_4_3 * x + \ + sh_4_4 * y - 0.632455532033676 * sh_4_5 * z + sh_5_6 = -0.346410161513776 * sh_4_2 * x + 0.774596669241483 * sh_4_4 * \ + z + 0.979795897113273 * sh_4_5 * y - 0.346410161513776 * sh_4_6 * z + sh_5_7 = -0.244948974278318 * sh_4_1 * x - 0.648074069840787 * sh_4_3 * x + \ + 0.648074069840786 * sh_4_5 * z + 0.916515138991169 * \ + sh_4_6 * y - 0.244948974278318 * sh_4_7 * z + sh_5_8 = -0.141421356237309 * sh_4_0 * x - 0.748331477354788 * sh_4_2 * x + \ + 
0.748331477354788 * sh_4_6 * z + 0.8 * \ + sh_4_7 * y - 0.141421356237309 * sh_4_8 * z + sh_5_9 = -0.848528137423857 * sh_4_1 * x + \ + 0.848528137423857 * sh_4_7 * z + 0.6 * sh_4_8 * y + sh_5_10 = -0.948683298050513 * sh_4_0 * x + 0.948683298050513 * sh_4_8 * z + return [sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10] + +def _sh6(x, y, z, prev): + """Compute spherical harmonics of degree 6.""" + sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10 = prev + + sh_6_0 = 0.957427107756337 * sh_5_0 * z + 0.957427107756338 * sh_5_10 * x + sh_6_1 = 0.552770798392565 * sh_5_0 * y + 0.874007373475125 * \ + sh_5_1 * z + 0.874007373475125 * sh_5_9 * x + sh_6_2 = -0.117851130197757 * sh_5_0 * z + 0.745355992499929 * sh_5_1 * y + \ + 0.117851130197758 * sh_5_10 * x + 0.790569415042094 * \ + sh_5_2 * z + 0.790569415042093 * sh_5_8 * x + sh_6_3 = -0.204124145231931 * sh_5_1 * z + 0.866025403784437 * sh_5_2 * y + \ + 0.707106781186546 * sh_5_3 * z + 0.707106781186547 * \ + sh_5_7 * x + 0.204124145231931 * sh_5_9 * x + sh_6_4 = -0.288675134594813 * sh_5_2 * z + 0.942809041582062 * sh_5_3 * y + \ + 0.623609564462323 * sh_5_4 * z + 0.623609564462322 * \ + sh_5_6 * x + 0.288675134594812 * sh_5_8 * x + sh_6_5 = -0.372677996249965 * sh_5_3 * z + 0.986013297183268 * sh_5_4 * \ + y + 0.763762615825972 * sh_5_5 * x + 0.372677996249964 * sh_5_7 * x + sh_6_6 = -0.645497224367901 * sh_5_4 * x + \ + sh_5_5 * y - 0.645497224367902 * sh_5_6 * z + sh_6_7 = -0.372677996249964 * sh_5_3 * x + 0.763762615825972 * sh_5_5 * \ + z + 0.986013297183269 * sh_5_6 * y - 0.372677996249965 * sh_5_7 * z + sh_6_8 = -0.288675134594813 * sh_5_2 * x - 0.623609564462323 * sh_5_4 * x + \ + 0.623609564462323 * sh_5_6 * z + 0.942809041582062 * \ + sh_5_7 * y - 0.288675134594812 * sh_5_8 * z + sh_6_9 = -0.20412414523193 * sh_5_1 * x - 0.707106781186546 * sh_5_3 * x + \ + 0.707106781186547 * sh_5_7 * z + 0.866025403784438 * \ + sh_5_8 * y - 0.204124145231931 * sh_5_9 * z + sh_6_10 = -0.117851130197757 * sh_5_0 * x - 0.117851130197757 * sh_5_10 * z - \ + 0.790569415042094 * sh_5_2 * x + 0.790569415042093 * \ + sh_5_8 * z + 0.745355992499929 * sh_5_9 * y + sh_6_11 = -0.874007373475124 * sh_5_1 * x + 0.552770798392566 * \ + sh_5_10 * y + 0.874007373475125 * sh_5_9 * z + sh_6_12 = -0.957427107756337 * sh_5_0 * x + 0.957427107756336 * sh_5_10 * z + return [sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12] + +def _sh7(x, y, z, prev): + """Compute spherical harmonics of degree 7.""" + sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12 = prev + + sh_7_0 = 0.963624111659433 * sh_6_0 * z + 0.963624111659432 * sh_6_12 * x + sh_7_1 = 0.515078753637713 * sh_6_0 * y + 0.892142571199771 * \ + sh_6_1 * z + 0.892142571199771 * sh_6_11 * x + sh_7_2 = -0.101015254455221 * sh_6_0 * z + 0.699854212223765 * sh_6_1 * y + \ + 0.82065180664829 * sh_6_10 * x + 0.101015254455222 * \ + sh_6_12 * x + 0.82065180664829 * sh_6_2 * z + sh_7_3 = -0.174963553055942 * sh_6_1 * z + 0.174963553055941 * sh_6_11 * x + \ + 0.82065180664829 * sh_6_2 * y + 0.749149177264394 * \ + sh_6_3 * z + 0.749149177264394 * sh_6_9 * x + sh_7_4 = 0.247435829652697 * sh_6_10 * x - 0.247435829652697 * sh_6_2 * z + \ + 0.903507902905251 * sh_6_3 * y + 0.677630927178938 * \ + sh_6_4 * z + 0.677630927178938 * sh_6_8 * x + sh_7_5 = -0.31943828249997 * sh_6_3 * z + 0.95831484749991 * sh_6_4 * y + \ + 0.606091526731326 * 
sh_6_5 * z + 0.606091526731326 * \ + sh_6_7 * x + 0.31943828249997 * sh_6_9 * x + sh_7_6 = -0.391230398217976 * sh_6_4 * z + 0.989743318610787 * sh_6_5 * \ + y + 0.755928946018454 * sh_6_6 * x + 0.391230398217975 * sh_6_8 * x + sh_7_7 = -0.654653670707977 * sh_6_5 * x + \ + sh_6_6 * y - 0.654653670707978 * sh_6_7 * z + sh_7_8 = -0.391230398217976 * sh_6_4 * x + 0.755928946018455 * sh_6_6 * \ + z + 0.989743318610787 * sh_6_7 * y - 0.391230398217975 * sh_6_8 * z + sh_7_9 = -0.31943828249997 * sh_6_3 * x - 0.606091526731327 * sh_6_5 * x + \ + 0.606091526731326 * sh_6_7 * z + 0.95831484749991 * \ + sh_6_8 * y - 0.31943828249997 * sh_6_9 * z + sh_7_10 = -0.247435829652697 * sh_6_10 * z - 0.247435829652697 * sh_6_2 * x - \ + 0.677630927178938 * sh_6_4 * x + 0.677630927178938 * \ + sh_6_8 * z + 0.903507902905251 * sh_6_9 * y + sh_7_11 = -0.174963553055942 * sh_6_1 * x + 0.820651806648289 * sh_6_10 * y - \ + 0.174963553055941 * sh_6_11 * z - 0.749149177264394 * \ + sh_6_3 * x + 0.749149177264394 * sh_6_9 * z + sh_7_12 = -0.101015254455221 * sh_6_0 * x + 0.82065180664829 * sh_6_10 * z + \ + 0.699854212223766 * sh_6_11 * y - 0.101015254455221 * \ + sh_6_12 * z - 0.82065180664829 * sh_6_2 * x + sh_7_13 = -0.892142571199772 * sh_6_1 * x + 0.892142571199772 * \ + sh_6_11 * z + 0.515078753637713 * sh_6_12 * y + sh_7_14 = -0.963624111659431 * sh_6_0 * x + 0.963624111659433 * sh_6_12 * z + + return [sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, + sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14] + +def _sh8(x, y, z, prev): + """Compute spherical harmonics of degree 8.""" + sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, \ + sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14 = prev + + sh_8_0 = 0.968245836551854 * sh_7_0 * z + \ + 0.968245836551853 * sh_7_14 * x + sh_8_1 = 0.484122918275928 * sh_7_0 * y + 0.90571104663684 * \ + sh_7_1 * z + 0.90571104663684 * sh_7_13 * x + sh_8_2 = -0.0883883476483189 * sh_7_0 * z + 0.661437827766148 * sh_7_1 * y + \ + 0.843171097702002 * sh_7_12 * x + 0.088388347648318 * \ + sh_7_14 * x + 0.843171097702003 * sh_7_2 * z + sh_8_3 = -0.153093108923948 * sh_7_1 * z + 0.7806247497998 * sh_7_11 * x + \ + 0.153093108923949 * sh_7_13 * x + 0.7806247497998 * \ + sh_7_2 * y + 0.780624749799799 * sh_7_3 * z + sh_8_4 = 0.718070330817253 * sh_7_10 * x + 0.21650635094611 * sh_7_12 * x - \ + 0.21650635094611 * sh_7_2 * z + 0.866025403784439 * \ + sh_7_3 * y + 0.718070330817254 * sh_7_4 * z + sh_8_5 = 0.279508497187474 * sh_7_11 * x - 0.279508497187474 * sh_7_3 * z + \ + 0.927024810886958 * sh_7_4 * y + 0.655505530106345 * \ + sh_7_5 * z + 0.655505530106344 * sh_7_9 * x + sh_8_6 = 0.342326598440729 * sh_7_10 * x - 0.342326598440729 * sh_7_4 * z + \ + 0.968245836551854 * sh_7_5 * y + 0.592927061281572 * \ + sh_7_6 * z + 0.592927061281571 * sh_7_8 * x + sh_8_7 = -0.405046293650492 * sh_7_5 * z + 0.992156741649221 * \ + sh_7_6 * y + 0.75 * sh_7_7 * x + 0.405046293650492 * sh_7_9 * x + sh_8_8 = -0.661437827766148 * sh_7_6 * x + \ + sh_7_7 * y - 0.661437827766148 * sh_7_8 * z + sh_8_9 = -0.405046293650492 * sh_7_5 * x + 0.75 * sh_7_7 * z + \ + 0.992156741649221 * sh_7_8 * y - 0.405046293650491 * sh_7_9 * z + sh_8_10 = -0.342326598440728 * sh_7_10 * z - 0.342326598440729 * sh_7_4 * x - \ + 0.592927061281571 * sh_7_6 * x + 0.592927061281571 * \ + sh_7_8 * z + 0.968245836551855 * sh_7_9 * y + sh_8_11 = 0.927024810886958 * sh_7_10 * y - 0.279508497187474 * sh_7_11 * z - \ + 0.279508497187474 * sh_7_3 * x - 0.655505530106345 * \ + sh_7_5 * x + 
0.655505530106345 * sh_7_9 * z + sh_8_12 = 0.718070330817253 * sh_7_10 * z + 0.866025403784439 * sh_7_11 * y - \ + 0.216506350946109 * sh_7_12 * z - 0.216506350946109 * \ + sh_7_2 * x - 0.718070330817254 * sh_7_4 * x + sh_8_13 = -0.153093108923948 * sh_7_1 * x + 0.7806247497998 * sh_7_11 * z + \ + 0.7806247497998 * sh_7_12 * y - 0.153093108923948 * \ + sh_7_13 * z - 0.780624749799799 * sh_7_3 * x + sh_8_14 = -0.0883883476483179 * sh_7_0 * x + 0.843171097702002 * sh_7_12 * z + \ + 0.661437827766147 * sh_7_13 * y - 0.088388347648319 * \ + sh_7_14 * z - 0.843171097702002 * sh_7_2 * x + sh_8_15 = -0.90571104663684 * sh_7_1 * x + 0.90571104663684 * \ + sh_7_13 * z + 0.484122918275927 * sh_7_14 * y + sh_8_16 = -0.968245836551853 * sh_7_0 * x + 0.968245836551855 * sh_7_14 * z + + return [sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, + sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, + sh_8_15, sh_8_16] + +def _sh9(x, y, z, prev): + """Compute spherical harmonics of degree 9.""" + sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, \ + sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, \ + sh_8_16 = prev + + sh_9_0 = 0.97182531580755 * sh_8_0 * z + 0.971825315807551 * sh_8_16 * x + sh_9_1 = 0.458122847290851 * sh_8_0 * y + 0.916245694581702 * \ + sh_8_1 * z + 0.916245694581702 * sh_8_15 * x + sh_9_2 = -0.078567420131839 * sh_8_0 * z + 0.62853936105471 * sh_8_1 * y + 0.86066296582387 * \ + sh_8_14 * x + 0.0785674201318385 * sh_8_16 * x + 0.860662965823871 * sh_8_2 * z + sh_9_3 = -0.136082763487955 * sh_8_1 * z + 0.805076485899413 * sh_8_13 * x + \ + 0.136082763487954 * sh_8_15 * x + 0.74535599249993 * \ + sh_8_2 * y + 0.805076485899413 * sh_8_3 * z + sh_9_4 = 0.749485420179558 * sh_8_12 * x + 0.192450089729875 * sh_8_14 * x - \ + 0.192450089729876 * sh_8_2 * z + 0.831479419283099 * \ + sh_8_3 * y + 0.749485420179558 * sh_8_4 * z + sh_9_5 = 0.693888666488711 * sh_8_11 * x + 0.248451997499977 * sh_8_13 * x - \ + 0.248451997499976 * sh_8_3 * z + 0.895806416477617 * \ + sh_8_4 * y + 0.69388866648871 * sh_8_5 * z + sh_9_6 = 0.638284738504225 * sh_8_10 * x + 0.304290309725092 * sh_8_12 * x - \ + 0.304290309725092 * sh_8_4 * z + 0.942809041582063 * \ + sh_8_5 * y + 0.638284738504225 * sh_8_6 * z + sh_9_7 = 0.360041149911548 * sh_8_11 * x - 0.360041149911548 * sh_8_5 * z + \ + 0.974996043043569 * sh_8_6 * y + 0.582671582316751 * \ + sh_8_7 * z + 0.582671582316751 * sh_8_9 * x + sh_9_8 = 0.415739709641549 * sh_8_10 * x - 0.415739709641549 * sh_8_6 * \ + z + 0.993807989999906 * sh_8_7 * y + 0.74535599249993 * sh_8_8 * x + sh_9_9 = -0.66666666666666666667 * sh_8_7 * x + \ + sh_8_8 * y - 0.66666666666666666667 * sh_8_9 * z + sh_9_10 = -0.415739709641549 * sh_8_10 * z - 0.415739709641549 * sh_8_6 * \ + x + 0.74535599249993 * sh_8_8 * z + 0.993807989999906 * sh_8_9 * y + sh_9_11 = 0.974996043043568 * sh_8_10 * y - 0.360041149911547 * sh_8_11 * z - \ + 0.360041149911548 * sh_8_5 * x - 0.582671582316751 * \ + sh_8_7 * x + 0.582671582316751 * sh_8_9 * z + sh_9_12 = 0.638284738504225 * sh_8_10 * z + 0.942809041582063 * sh_8_11 * y - \ + 0.304290309725092 * sh_8_12 * z - 0.304290309725092 * \ + sh_8_4 * x - 0.638284738504225 * sh_8_6 * x + sh_9_13 = 0.693888666488711 * sh_8_11 * z + 0.895806416477617 * sh_8_12 * y - \ + 0.248451997499977 * sh_8_13 * z - 0.248451997499977 * \ + sh_8_3 * x - 0.693888666488711 * sh_8_5 * x + sh_9_14 = 0.749485420179558 * sh_8_12 * z + 0.831479419283098 * sh_8_13 * y - \ + 0.192450089729875 * sh_8_14 * z - 0.192450089729875 * \ 
+ sh_8_2 * x - 0.749485420179558 * sh_8_4 * x + sh_9_15 = -0.136082763487954 * sh_8_1 * x + 0.805076485899413 * sh_8_13 * z + \ + 0.745355992499929 * sh_8_14 * y - 0.136082763487955 * \ + sh_8_15 * z - 0.805076485899413 * sh_8_3 * x + sh_9_16 = -0.0785674201318389 * sh_8_0 * x + 0.86066296582387 * sh_8_14 * z + \ + 0.628539361054709 * sh_8_15 * y - 0.0785674201318387 * \ + sh_8_16 * z - 0.860662965823871 * sh_8_2 * x + sh_9_17 = -0.9162456945817 * sh_8_1 * x + 0.916245694581702 * \ + sh_8_15 * z + 0.458122847290851 * sh_8_16 * y + sh_9_18 = -0.97182531580755 * sh_8_0 * x + 0.97182531580755 * sh_8_16 * z + + return [sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, + sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, + sh_9_15, sh_9_16, sh_9_17, sh_9_18] + +def _sh10(x, y, z, prev): + """Compute spherical harmonics of degree 10.""" + sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, \ + sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, \ + sh_9_16, sh_9_17, sh_9_18 = prev + + sh_10_0 = 0.974679434480897 * sh_9_0 * z + 0.974679434480897 * sh_9_18 * x + sh_10_1 = 0.435889894354067 * sh_9_0 * y + 0.924662100445347 * \ + sh_9_1 * z + 0.924662100445347 * sh_9_17 * x + sh_10_2 = -0.0707106781186546 * sh_9_0 * z + 0.6 * sh_9_1 * y + 0.874642784226796 * \ + sh_9_16 * x + 0.070710678118655 * sh_9_18 * x + 0.874642784226795 * sh_9_2 * z + sh_10_3 = -0.122474487139159 * sh_9_1 * z + 0.824621125123533 * sh_9_15 * x + \ + 0.122474487139159 * sh_9_17 * x + 0.714142842854285 * \ + sh_9_2 * y + 0.824621125123533 * sh_9_3 * z + sh_10_4 = 0.774596669241484 * sh_9_14 * x + 0.173205080756887 * sh_9_16 * x - \ + 0.173205080756888 * sh_9_2 * z + 0.8 * \ + sh_9_3 * y + 0.774596669241483 * sh_9_4 * z + sh_10_5 = 0.724568837309472 * sh_9_13 * x + 0.223606797749979 * sh_9_15 * x - \ + 0.223606797749979 * sh_9_3 * z + 0.866025403784438 * \ + sh_9_4 * y + 0.724568837309472 * sh_9_5 * z + sh_10_6 = 0.674536878161602 * sh_9_12 * x + 0.273861278752583 * sh_9_14 * x - \ + 0.273861278752583 * sh_9_4 * z + 0.916515138991168 * \ + sh_9_5 * y + 0.674536878161602 * sh_9_6 * z + sh_10_7 = 0.62449979983984 * sh_9_11 * x + 0.324037034920393 * sh_9_13 * x - \ + 0.324037034920393 * sh_9_5 * z + 0.953939201416946 * \ + sh_9_6 * y + 0.62449979983984 * sh_9_7 * z + sh_10_8 = 0.574456264653803 * sh_9_10 * x + 0.374165738677394 * sh_9_12 * x - \ + 0.374165738677394 * sh_9_6 * z + 0.979795897113272 * \ + sh_9_7 * y + 0.574456264653803 * sh_9_8 * z + sh_10_9 = 0.424264068711928 * sh_9_11 * x - 0.424264068711929 * sh_9_7 * \ + z + 0.99498743710662 * sh_9_8 * y + 0.741619848709567 * sh_9_9 * x + sh_10_10 = -0.670820393249937 * sh_9_10 * z - \ + 0.670820393249937 * sh_9_8 * x + sh_9_9 * y + sh_10_11 = 0.99498743710662 * sh_9_10 * y - 0.424264068711929 * sh_9_11 * \ + z - 0.424264068711929 * sh_9_7 * x + 0.741619848709567 * sh_9_9 * z + sh_10_12 = 0.574456264653803 * sh_9_10 * z + 0.979795897113272 * sh_9_11 * y - \ + 0.374165738677395 * sh_9_12 * z - 0.374165738677394 * \ + sh_9_6 * x - 0.574456264653803 * sh_9_8 * x + sh_10_13 = 0.62449979983984 * sh_9_11 * z + 0.953939201416946 * sh_9_12 * y - \ + 0.324037034920393 * sh_9_13 * z - 0.324037034920393 * \ + sh_9_5 * x - 0.62449979983984 * sh_9_7 * x + sh_10_14 = 0.674536878161602 * sh_9_12 * z + 0.916515138991168 * sh_9_13 * y - \ + 0.273861278752583 * sh_9_14 * z - 0.273861278752583 * \ + sh_9_4 * x - 0.674536878161603 * sh_9_6 * x + sh_10_15 = 0.724568837309472 * sh_9_13 * z + 0.866025403784439 * sh_9_14 * y - \ + 0.223606797749979 * 
sh_9_15 * z - 0.223606797749979 * \ + sh_9_3 * x - 0.724568837309472 * sh_9_5 * x + sh_10_16 = 0.774596669241484 * sh_9_14 * z + 0.8 * sh_9_15 * y - 0.173205080756888 * \ + sh_9_16 * z - 0.173205080756887 * sh_9_2 * x - 0.774596669241484 * sh_9_4 * x + sh_10_17 = -0.12247448713916 * sh_9_1 * x + 0.824621125123532 * sh_9_15 * z + \ + 0.714142842854285 * sh_9_16 * y - 0.122474487139158 * \ + sh_9_17 * z - 0.824621125123533 * sh_9_3 * x + sh_10_18 = -0.0707106781186548 * sh_9_0 * x + 0.874642784226796 * sh_9_16 * z + \ + 0.6 * sh_9_17 * y - 0.0707106781186546 * \ + sh_9_18 * z - 0.874642784226796 * sh_9_2 * x + sh_10_19 = -0.924662100445348 * sh_9_1 * x + 0.924662100445347 * \ + sh_9_17 * z + 0.435889894354068 * sh_9_18 * y + sh_10_20 = -0.974679434480898 * sh_9_0 * x + 0.974679434480896 * sh_9_18 * z + + return [sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, + sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, + sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, + sh_10_19, sh_10_20] + +def _sh11(x, y, z, prev): + """Compute spherical harmonics of degree 11.""" + sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, \ + sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, sh_10_13, sh_10_14, \ + sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20 = prev + + sh_11_0 = 0.977008420918394 * sh_10_0 * z + 0.977008420918394 * sh_10_20 * x + sh_11_1 = 0.416597790450531 * sh_10_0 * y + 0.9315409787236 * \ + sh_10_1 * z + 0.931540978723599 * sh_10_19 * x + sh_11_2 = -0.0642824346533223 * sh_10_0 * z + 0.574959574576069 * sh_10_1 * y + \ + 0.88607221316445 * sh_10_18 * x + 0.886072213164452 * \ + sh_10_2 * z + 0.0642824346533226 * sh_10_20 * x + sh_11_3 = -0.111340442853781 * sh_10_1 * z + 0.84060190949577 * sh_10_17 * x + \ + 0.111340442853781 * sh_10_19 * x + 0.686348585024614 * \ + sh_10_2 * y + 0.840601909495769 * sh_10_3 * z + sh_11_4 = 0.795129803842541 * sh_10_16 * x + 0.157459164324444 * sh_10_18 * x - \ + 0.157459164324443 * sh_10_2 * z + 0.771389215839871 * \ + sh_10_3 * y + 0.795129803842541 * sh_10_4 * z + sh_11_5 = 0.74965556829412 * sh_10_15 * x + 0.203278907045435 * sh_10_17 * x - \ + 0.203278907045436 * sh_10_3 * z + 0.838140405208444 * \ + sh_10_4 * y + 0.74965556829412 * sh_10_5 * z + sh_11_6 = 0.70417879021953 * sh_10_14 * x + 0.248964798865985 * sh_10_16 * x - \ + 0.248964798865985 * sh_10_4 * z + 0.890723542830247 * \ + sh_10_5 * y + 0.704178790219531 * sh_10_6 * z + sh_11_7 = 0.658698943008611 * sh_10_13 * x + 0.294579122654903 * sh_10_15 * x - \ + 0.294579122654903 * sh_10_5 * z + 0.9315409787236 * \ + sh_10_6 * y + 0.658698943008611 * sh_10_7 * z + sh_11_8 = 0.613215343783275 * sh_10_12 * x + 0.340150671524904 * sh_10_14 * x - \ + 0.340150671524904 * sh_10_6 * z + 0.962091385841669 * \ + sh_10_7 * y + 0.613215343783274 * sh_10_8 * z + sh_11_9 = 0.567727090763491 * sh_10_11 * x + 0.385694607919935 * sh_10_13 * x - \ + 0.385694607919935 * sh_10_7 * z + 0.983332166035633 * \ + sh_10_8 * y + 0.56772709076349 * sh_10_9 * z + sh_11_10 = 0.738548945875997 * sh_10_10 * x + 0.431219680932052 * sh_10_12 * \ + x - 0.431219680932052 * sh_10_8 * z + 0.995859195463938 * sh_10_9 * y + sh_11_11 = sh_10_10 * y - 0.674199862463242 * \ + sh_10_11 * z - 0.674199862463243 * sh_10_9 * x + sh_11_12 = 0.738548945875996 * sh_10_10 * z + 0.995859195463939 * sh_10_11 * \ + y - 0.431219680932052 * sh_10_12 * z - 0.431219680932053 * sh_10_8 * x + sh_11_13 = 0.567727090763491 * sh_10_11 * z + 0.983332166035634 * sh_10_12 * y - \ + 0.385694607919935 * 
sh_10_13 * z - 0.385694607919935 * \
+        sh_10_7 * x - 0.567727090763491 * sh_10_9 * x
+    sh_11_14 = 0.613215343783275 * sh_10_12 * z + 0.96209138584167 * sh_10_13 * y - \
+        0.340150671524904 * sh_10_14 * z - 0.340150671524904 * \
+        sh_10_6 * x - 0.613215343783274 * sh_10_8 * x
+    sh_11_15 = 0.658698943008611 * sh_10_13 * z + 0.9315409787236 * sh_10_14 * y - \
+        0.294579122654903 * sh_10_15 * z - 0.294579122654903 * \
+        sh_10_5 * x - 0.65869894300861 * sh_10_7 * x
+    sh_11_16 = 0.70417879021953 * sh_10_14 * z + 0.890723542830246 * sh_10_15 * y - \
+        0.248964798865985 * sh_10_16 * z - 0.248964798865985 * \
+        sh_10_4 * x - 0.70417879021953 * sh_10_6 * x
+    sh_11_17 = 0.749655568294121 * sh_10_15 * z + 0.838140405208444 * sh_10_16 * y - \
+        0.203278907045436 * sh_10_17 * z - 0.203278907045435 * \
+        sh_10_3 * x - 0.749655568294119 * sh_10_5 * x
+    sh_11_18 = 0.79512980384254 * sh_10_16 * z + 0.77138921583987 * sh_10_17 * y - \
+        0.157459164324443 * sh_10_18 * z - 0.157459164324444 * \
+        sh_10_2 * x - 0.795129803842541 * sh_10_4 * x
+    sh_11_19 = -0.111340442853782 * sh_10_1 * x + 0.84060190949577 * sh_10_17 * z + \
+        0.686348585024614 * sh_10_18 * y - 0.111340442853781 * \
+        sh_10_19 * z - 0.840601909495769 * sh_10_3 * x
+    sh_11_20 = -0.0642824346533226 * sh_10_0 * x + 0.886072213164451 * sh_10_18 * z + \
+        0.57495957457607 * sh_10_19 * y - 0.886072213164451 * \
+        sh_10_2 * x - 0.0642824346533228 * sh_10_20 * z
+    sh_11_21 = -0.9315409787236 * sh_10_1 * x + 0.931540978723599 * \
+        sh_10_19 * z + 0.416597790450531 * sh_10_20 * y
+    sh_11_22 = -0.977008420918393 * sh_10_0 * x + 0.977008420918393 * sh_10_20 * z
+
+    return [sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6,
+            sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, sh_11_12,
+            sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18,
+            sh_11_19, sh_11_20, sh_11_21, sh_11_22]
+
+def _spherical_harmonics(lmax: int, x, y, z):
+    """Compute spherical harmonics up to degree lmax (11 at most)."""
+    results = []
+
+    # l = 0
+    sh0 = _sh0(x)
+    results.extend(sh0)
+    if lmax == 0:
+        return ops.stack(results, axis=-1)
+
+    # l = 1
+    sh1 = _sh1(x, y, z)
+    results.extend(sh1)
+    if lmax == 1:
+        return ops.stack(results, axis=-1)
+
+    # l = 2
+    sh2 = _sh2(x, y, z)
+    results.extend(sh2)
+    if lmax == 2:
+        return ops.stack(results, axis=-1)
+
+    # l = 3
+    sh3 = _sh3(x, y, z, sh2)
+    results.extend(sh3)
+    if lmax == 3:
+        return ops.stack(results, axis=-1)
+
+    # l = 4
+    sh4 = _sh4(x, y, z, sh3)
+    results.extend(sh4)
+    if lmax == 4:
+        return ops.stack(results, axis=-1)
+
+    # l = 5
+    sh5 = _sh5(x, y, z, sh4)
+    results.extend(sh5)
+    if lmax == 5:
+        return ops.stack(results, axis=-1)
+
+    # l = 6
+    sh6 = _sh6(x, y, z, sh5)
+    results.extend(sh6)
+    if lmax == 6:
+        return ops.stack(results, axis=-1)
+
+    # l = 7
+    sh7 = _sh7(x, y, z, sh6)
+    results.extend(sh7)
+    if lmax == 7:
+        return ops.stack(results, axis=-1)
+
+    # l = 8
+    sh8 = _sh8(x, y, z, sh7)
+    results.extend(sh8)
+    if lmax == 8:
+        return ops.stack(results, axis=-1)
+
+    # l = 9
+    sh9 = _sh9(x, y, z, sh8)
+    results.extend(sh9)
+    if lmax == 9:
+        return ops.stack(results, axis=-1)
+
+    # l = 10
+    sh10 = _sh10(x, y, z, sh9)
+    results.extend(sh10)
+    if lmax == 10:
+        return ops.stack(results, axis=-1)
+
+    # l = 11
+    sh11 = _sh11(x, y, z, sh10)
+    results.extend(sh11)
+    if lmax == 11:
+        return ops.stack(results, axis=-1)
+
+    # Degrees above 11 have no closed-form recursion implemented here; raise the
+    # NotImplementedError promised by `spherical_harmonics` instead of silently
+    # returning a truncated result.
+    raise NotImplementedError(f"spherical harmonics are only implemented up to lmax = 11, got lmax = {lmax}")
diff --git a/mindscience/e3nn/o3/sub.py b/mindscience/e3nn/o3/sub.py new file mode 100644 index
0000000000000000000000000000000000000000..03ebe60cf60b56a0acb5653605523d257c17794c --- /dev/null +++ b/mindscience/e3nn/o3/sub.py @@ -0,0 +1,503 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""sub""" +from typing import NamedTuple +from mindspore.common.parameter import Parameter +from mindspore.ops import operations as P +from mindspore import ops, float32 +from .tensor_product import TensorProduct +from .irreps import Irreps +from ..utils.func import narrow + + +class FullyConnectedTensorProduct(TensorProduct): + r""" + Fully-connected weighted tensor product. All the possible path allowed + by :math:`|l_1 - l_2| \leq l_{out} \leq l_1 + l_2` are made. + Equivalent to `TensorProduct` with `instructions='connect'`. + For details, see :class:`mindchemistry.e3.o3.TensorProduct`. + + Args: + irreps_in1 (Union[str, Irrep, Irreps]): Irreps for the first input. + irreps_in2 (Union[str, Irrep, Irreps]): Irreps for the second input. + irreps_out (Union[str, Irrep, Irreps]): Irreps for the output. + irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations. + Default: 'component'. Default: 'component'. + path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: 'element'. + weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', + 'xavier_uniform'}, the initial method of weights. Default: 'normal'. + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` . + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import FullyConnectedTensorProduct + >>> FullyConnectedTensorProduct('2x1o', '1x1o+3x0e', '5x2e+4x1o') + TensorProduct [connect] (2x1o x 1x1o+3x0e -> 5x2e+4x1o) + + """ + + def __init__(self, + irreps_in1, + irreps_in2, + irreps_out, + ncon_dtype=float32, + **kwargs): + super().__init__(irreps_in1, + irreps_in2, + irreps_out, + instructions='connect', + ncon_dtype=ncon_dtype, + **kwargs) + + +class FullTensorProduct(TensorProduct): + r""" + Full tensor product between two irreps. + + Equivalent to `TensorProduct` with `instructions='full'`. + For details, see :class:`mindchemistry.e3.o3.TensorProduct`. + + Args: + irreps_in1 (Union[str, Irrep, Irreps]): Irreps for the first input. + irreps_in2 (Union[str, Irrep, Irreps]): Irreps for the second input. + filter_ir_out (Union[str, Irrep, Irreps, None]): Filter to select only specific `Irrep` + of the output. Default: None. + irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations. + Default: 'component'. Default: 'component'. + path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: 'element'. 
+ weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', + 'xavier_uniform'}, the initial method of weights. Default: 'normal'. + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` . + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import FullTensorProduct + >>> FullTensorProduct('2x1o+4x0o', '1x1o+3x0e') + TensorProduct [full] (2x1o+4x0o x 1x1o+3x0e -> 2x0e+12x0o+6x1o+2x1e+4x1e+2x2e) + + """ + + def __init__(self, + irreps_in1, + irreps_in2, + filter_ir_out=None, + ncon_dtype=float32, + **kwargs): + super().__init__(irreps_in1, + irreps_in2, + filter_ir_out, + instructions='full', + ncon_dtype=ncon_dtype, + **kwargs) + + +class ElementwiseTensorProduct(TensorProduct): + r""" + Elementwise connected tensor product. + + Equivalent to `TensorProduct` with `instructions='element'`. + For details, see :class:`mindchemistry.e3.o3.TensorProduct`. + + Args: + irreps_in1 (Union[str, Irrep, Irreps]): Irreps for the first input. + irreps_in2 (Union[str, Irrep, Irreps]): Irreps for the second input. + filter_ir_out (Union[str, Irrep, Irreps, None]): Filter to select only specific `Irrep` of the output. + Default: None. + irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations. + Default: 'component'. Default: 'component'. + path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: 'element'. + weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', + 'xavier_uniform'}, the initial method of weights. Default: 'normal'. + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` . + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import ElementwiseTensorProduct + >>> ElementwiseTensorProduct('2x2e+4x1o', '3x1e+3x0o') + TensorProduct [element] (2x2e+1x1o+3x1o x 2x1e+1x1e+3x0o -> 2x1e+2x2e+2x3e+1x0o+1x1o+1x2o+3x1e) + + """ + + def __init__(self, + irreps_in1, + irreps_in2, + filter_ir_out=None, + ncon_dtype=float32, + **kwargs): + super().__init__(irreps_in1, + irreps_in2, + filter_ir_out, + instructions='element', + ncon_dtype=ncon_dtype, + **kwargs) + + +class Linear(TensorProduct): + r""" + Linear operation equivariant. + + Equivalent to `TensorProduct` with `instructions='linear'`. + For details, see :class:`mindchemistry.e3.o3.TensorProduct`. + + Args: + irreps_in (Union[str, Irrep, Irreps]): Irreps for the input. + irreps_out (Union[str, Irrep, Irreps]): Irreps for the output. + irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations. + Default: ``'component'``. + path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: ``'element'``. + weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', + 'xavier_uniform'}, the initial method of weights. Default: ``'normal'``. + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` . 
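+
+    Note:
+        As an equivariant map, `Linear` only mixes channels that share the same irrep:
+        in the example below, the `2x2e` input channels can only feed the `3x2e` output
+        block, never the `5x1o` or `2x0e` blocks.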
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import Linear + >>> Linear('2x2e+3x1o+3x0e', '3x2e+5x1o+2x0e') + TensorProduct [linear] (2x2e+3x1o+3x0e x 1x0e -> 3x2e+5x1o+2x0e) + + """ + + def __init__(self, irreps_in, irreps_out, ncon_dtype=float32, **kwargs): + super().__init__(irreps_in, + None, + irreps_out, + instructions='linear', + ncon_dtype=ncon_dtype, + **kwargs) + + +class Instruction(NamedTuple): + """A single tensor product path instruction.""" + i_in: int + i_out: int + path_shape: tuple + path_weight: float + + +def prod(x): + """Compute the product of the elements of a sequence.""" + out = 1 + for a in x: + out *= a + return out + + +# Private alias kept for the internal callers below. +_prod = prod + + +def _sum_tensors_withbias(xs, shape, dtype): + """sum tensors of same irrep.""" + if xs: + if len(xs[0].shape) == 1: + out = xs[0] + else: + out = xs[0].reshape(shape) + for x in xs[1:]: + if len(x.shape) == 1: + out = out + x + else: + out = out + x.reshape(shape) + return out + return ops.zeros(shape, dtype=dtype) + + +def _compose(tensors, ir_data, instructions, batch_shape): + """compose list of tensor `tensors` into a 1d-tensor by `ir_data`.""" + res = [] + for i_out, mir_out in enumerate(ir_data): + if mir_out.mul > 0: + res.append( + _sum_tensors_withbias([ + out for ins, out in zip(instructions, tensors) + if ins['i_out'] == i_out + ], + shape=batch_shape + (mir_out.dim,), + dtype=tensors[0].dtype)) + + if len(res) > 1: + res = ops.concat(res, axis=-1) + else: + res = res[0] + return res + + +def _run_continue(ir1_data, ir2_data, irout_data, ins): + """check trivial computations""" + mir_in1 = ir1_data[ins['indice_one']] + mir_in2 = ir2_data[ins['indice_two']] + mir_out = irout_data[ins['i_out']] + if mir_in1.dim == 0 or mir_in2.dim == 0 or mir_out.dim == 0: + return True + return False + + +class LinearBias(TensorProduct): + r""" + Equivariant linear operation with an option to add bias. + + Equivalent to `TensorProduct` with `instructions='linear'` with option to add bias. For details, + see :class:`mindchemistry.e3.o3.TensorProduct`. + + Args: + irreps_in (Union[str, Irrep, Irreps]): Irreps for the input. + irreps_out (Union[str, Irrep, Irreps]): Irreps for the output. + irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations. + Default: ``'component'``. + path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: ``'element'``. + weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', + 'xavier_uniform'}, the initial method of weights. Default: ``'normal'``. + has_bias (bool): Whether to add bias to the calculation. + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` .
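The `_sum_tensors_withbias`/`_compose` pair above implements a simple contract: path outputs aimed at the same output irrep are summed, then all output irreps are concatenated along the last axis. A NumPy sketch of that behaviour (illustrative only):

```python
import numpy as np

def compose(path_outputs, targets, out_dims, batch_shape):
    """Sum per-irrep path outputs, then concatenate them into one tensor."""
    chunks = []
    for i_out, dim in enumerate(out_dims):
        parts = [t.reshape(batch_shape + (dim,))
                 for t, i in zip(path_outputs, targets) if i == i_out]
        chunks.append(sum(parts) if parts else np.zeros(batch_shape + (dim,)))
    return np.concatenate(chunks, axis=-1)

# Two paths feed output irrep 0 (dim 3), one path feeds output irrep 1 (dim 5):
outs = [np.ones((2, 3)), np.ones((2, 3)), np.ones((2, 5))]
res = compose(outs, targets=[0, 0, 1], out_dims=[3, 5], batch_shape=(2,))
print(res.shape)  # (2, 8): the first 3 columns hold 2.0, the last 5 hold 1.0
```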
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import LinearBias + >>> LinearBias('2x2e+3x1o+3x0e', '3x2e+5x1o+2x0e') + TensorProduct [linear] (2x2e+3x1o+3x0e x 1x0e -> 3x2e+5x1o+2x0e) + + """ + + def __init__(self, + irreps_in, + irreps_out, + has_bias, + ncon_dtype=float32, + **kwargs): + super().__init__(irreps_in, + None, + irreps_out, + instructions='linear', + ncon_dtype=ncon_dtype, + **kwargs) + irreps_in = Irreps(irreps_in) + irreps_out = Irreps(irreps_out) + + biases = [has_bias and ir.is_scalar() for _, ir in irreps_out] + + is_scalar_num = biases.count(True) + + instructions = [ + Instruction(i_in=-1, + i_out=i_out, + path_shape=(mul_ir.dim,), + path_weight=1.0) + for i_out, (bias, mul_ir) in enumerate(zip(biases, irreps_out)) + if bias + ] + self.has_bias = has_bias + self.bias_numel = None + self.bias_instructions = None + if self.has_bias: + self.bias_instructions = [] + for i_out, (bias, mul_ir) in enumerate(zip(biases, self.irreps_out)): + if bias: + path_shape = (mul_ir.dim,) + path_weight = 1.0 + instruction = Instruction(i_in=-1, i_out=i_out, path_shape=path_shape, path_weight=path_weight) + self.bias_instructions.append(instruction) + + if is_scalar_num == 1: + self.bias_numel = sum(irreps_out.data[i.i_out].dim + for i in instructions if i.i_in == -1) + bias = ops.zeros((self.bias_numel)) + self.bias = Parameter(bias, name="bias") + self.instr.append({ + "i_out": self.bias_instructions[0].i_out, + "indice_one": self.bias_instructions[0].i_in + }) + else: + bias = ops.zeros((is_scalar_num, 1)) + self.bias = Parameter(bias, name="bias") + + for bias_instr in self.bias_instructions: + self.instr.append({ + "i_out": bias_instr.i_out, + "indice_one": bias_instr.i_in + }) + + self.bias_add = P.BiasAdd() + self.ncon_dtype = ncon_dtype + + def construct(self, v1, v2=None, weight=None): + """Implement tensor product for input tensors.""" + self._weight_check(weight) + + if self._in2_is_none: + if v2 is not None: + raise ValueError(f"This tensor product should input 1 tensor.") + + if self._mode == 'linear': + v2_shape = v1.shape[:-1] + (1,) + v2 = ops.ones(v2_shape, v1.dtype) + else: + v2 = v1.copy() + else: + if v2 is None: + raise ValueError( + f"This tensor product should input 2 tensors.") + if self._mode == 'linear': + v2_shape = v1.shape[:-1] + (1,) + v2 = ops.ones(v2_shape, v1.dtype) + + batch_shape = v1.shape[:-1] + + v2s = self.irreps_in2.decompose(v2, batch=True) + v1s = self.irreps_in1.decompose(v1, batch=True) + + weight = self._get_weights(weight) + + if not (v1.shape[-1] == self.irreps_in1.dim + and v2.shape[-1] == self.irreps_in2.dim): + raise ValueError(f"The shape of input tensors do not match.") + + v3_list = [] + weight_ind = 0 + fn = 0 + index_one = 'indice_one' + index_two = 'indice_two' + index_wigner = 'wigner_matrix' + + for ins in self.instr: + if ins[index_one] == -1 or _run_continue(self.irreps_in1.data, + self.irreps_in2.data, + self.irreps_out.data, ins): + continue + fn = self._ncons[ins['i_ncon']] + if ins['has_weight']: + l = _prod(ins['path_shape']) + w = narrow( + weight, -1, weight_ind, + l).reshape(( + (-1,) if self.weight_mode == 'custom' else ()) + + ins['path_shape']).astype(self.ncon_dtype) + weight_ind += l + if self.core_mode == 'einsum': + v3 = fn((ins[index_wigner].astype(self.ncon_dtype), + v1s[ins[index_one]].astype(self.ncon_dtype), + v2s[ins[index_two]].astype(self.ncon_dtype), w)) + else: + v3 = fn([ + ins[index_wigner].astype(self.ncon_dtype), + v1s[ins[index_one]].astype(self.ncon_dtype), + 
v2s[ins[index_two]].astype(self.ncon_dtype), w + ]) + else: + if self.core_mode == 'einsum': + v3 = fn((ins[index_wigner].astype(self.ncon_dtype), + v1s[ins[index_one]].astype(self.ncon_dtype), + v2s[ins[index_two]].astype(self.ncon_dtype))) + else: + v3 = fn([ + ins[index_wigner].astype(self.ncon_dtype), + v1s[ins[index_one]].astype(self.ncon_dtype), + v2s[ins[index_two]].astype(self.ncon_dtype) + ]) + v3_list.append(ins['path_weight'].astype(self.dtype) * + v3.astype(self.dtype)) + + if self.has_bias: + if len(self.bias_instructions) == 1: + v3_list.append(self.bias) + else: + for i in range(len(self.bias_instructions)): + v3_list.append(self.bias[i]) + + v_out = _compose(v3_list, self.irreps_out.data, self.instr, + batch_shape) + + return v_out + + +class TensorSquare(TensorProduct): + r""" + Compute the square tensor product of a tensor. + + Equivalent to `TensorProduct` with `irreps_in2=None and instructions='full' or 'connect'`. For details, + see :class:`mindchemistry.e3.o3.TensorProduct`. + + If `irreps_out` is given, this operation is fully connected. + If `irreps_out` is not given, the operation has no parameter and is like full tensor product. + + Args: + irreps_in (Union[str, Irrep, Irreps]): Irreps for the input. + irreps_out (Union[str, Irrep, Irreps, None]): Irreps for the output. Default: ``None``. + filter_ir_out (Union[str, Irrep, Irreps, None]): Filter to select only specific `Irrep` of the output. + Default: ``None``. + irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations. + Default: ``'component'``. + path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: ``'element'``. + weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', + 'xavier_uniform'}, the initial method of weights. Default: 'normal'. + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` . + + Raises: + ValueError: If both `irreps_out` and `filter_ir_out` are not None. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import TensorSquare + >>> TensorSquare('2x1o', irreps_out='5x2e+4x1e+7x1o') + TensorProduct [connect] (2x1o x 2x1o -> 5x2e+4x1e) + >>> TensorSquare('2x1o+3x0e', filter_ir_out='5x2o+4x1e+2x0e') + TensorProduct [full] (2x1o+3x0e x 2x1o+3x0e -> 4x0e+9x0e+4x1e) + + """ + + def __init__(self, + irreps_in, + irreps_out=None, + filter_ir_out=None, + ncon_dtype=float32, + **kwargs): + if irreps_out is None: + super().__init__(irreps_in, + None, + filter_ir_out, + instructions='full', + ncon_dtype=ncon_dtype, + **kwargs) + else: + if filter_ir_out is None: + super().__init__(irreps_in, + None, + irreps_out, + instructions='connect', + ncon_dtype=ncon_dtype, + **kwargs) + else: + raise ValueError( + "Both `irreps_out` and `filter_ir_out` are not None, this is ambiguous." + ) diff --git a/mindscience/e3nn/o3/tensor_product.py b/mindscience/e3nn/o3/tensor_product.py new file mode 100644 index 0000000000000000000000000000000000000000..0f281fc11a1186895e21f7769f6c98ee7b297a75 --- /dev/null +++ b/mindscience/e3nn/o3/tensor_product.py @@ -0,0 +1,768 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
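The two `TensorSquare` branches above behave quite differently, so a short usage sketch may help. It is hedged: it assumes the `mindchemistry.e3.o3` import path shown in the docstrings of this diff:

```python
from mindchemistry.e3.o3 import TensorSquare

# Fully connected square: `irreps_out` fixes the outputs and weights are learned.
sq_connect = TensorSquare('2x1o', irreps_out='5x2e+4x1e+7x1o')

# Weight-free square: `filter_ir_out` merely filters the full product.
sq_full = TensorSquare('2x1o+3x0e', filter_ir_out='5x2o+4x1e+2x0e')

# Passing both selectors at once is rejected as ambiguous:
try:
    TensorSquare('2x1o', irreps_out='4x1e', filter_ir_out='2x0e')
except ValueError as err:
    print(err)
```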
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor, nn, ops, Parameter, get_context, float32, int32, vmap +from mindspore.common.initializer import initializer +import mindspore as ms +from .irreps import Irreps +from .wigner import wigner_3j +from ..utils.ncon import Ncon +from ..utils.func import narrow +from ..utils.initializer import renormal_initializer +import numpy as np +from mindspore.numpy import tensordot + +def _prod(x): + out = 1 + for i in x: + out *= i + return out + + +sqrt = ops.Sqrt() +zeros = ops.Zeros() + +def _sqrt(x, dtype=float32): + """sqrt operator with producing a tensor""" + return sqrt(Tensor(x, dtype=dtype)) + + +def _sum_tensors(xs, shape, dtype): + """sum tensors of same irrep.""" + if len(xs) > 0: + out = xs[0].reshape(shape) + for x in xs[1:]: + out = out + x.reshape(shape) + return out + return zeros(shape, dtype) + + +def _compose(tensors, ir_data, instructions, batch_shape): + """compose list of tensor `tensors` into a 1d-tensor by `ir_data`.""" + res = [] + for i_out, mir_out in enumerate(ir_data): + if mir_out.mul > 0: + res.append(_sum_tensors([out for ins, out in zip(instructions, tensors) + if ins['i_out'] == i_out], shape=batch_shape + (mir_out.dim,), + dtype=tensors[0].dtype)) + if len(res) > 1: + res = ops.concat(res, axis=-1) + else: + res = res[0] + return res + + +def _connect_init(irreps_in1, irreps_in2, irreps_out): + """Input initial for 'connect' mode.""" + full_out = (irreps_in1 * irreps_in2).simplify() + irreps_out = full_out if irreps_out is None else Irreps(irreps_out) + + instr = [] + for i_1, (_, ir_1) in enumerate(irreps_in1.data): + for i_2, (_, ir_2) in enumerate(irreps_in2.data): + ir_out_list = list(ir_1 * ir_2) + for i_out, (_, ir_out) in enumerate(irreps_out.data): + if ir_out in ir_out_list: + instr.append((i_1, i_2, i_out, 'uvw', True)) + + return irreps_out, instr + + +def _full_init(irreps_in1, irreps_in2, irreps_out): + """Input initial for 'full' mode.""" + full_out = irreps_in1 * irreps_in2 + irreps_out = full_out.filter(irreps_out) + + instr = [] + for i_1, (mul_1, ir_1) in enumerate(irreps_in1.data): + for i_2, (mul_2, ir_2) in enumerate(irreps_in2.data): + ir_out_list = list(ir_1 * ir_2) + for i_out, (mul_out, ir_out) in enumerate(irreps_out.data): + if ir_out in ir_out_list and mul_out == mul_1 * mul_2: + instr.append((i_1, i_2, i_out, 'uvuv', False)) + + return irreps_out, instr + + +def _element_init(irreps_in1, irreps_in2, irreps_out): + """Input initial for 'element' mode.""" + irreps_out = None if irreps_out is None else Irreps(irreps_out) + + if not irreps_in1.num_irreps == irreps_in2.num_irreps: + raise ValueError( + f"The total multiplicities of irreps_in1 {irreps_in1} and irreps_in2 {irreps_in2} should be equal.") + + irreps_in1_list = list(Irreps(irreps_in1).simplify().data) + irreps_in2_list = list(Irreps(irreps_in2).simplify().data) + + i = 0 + while i < len(irreps_in1_list): + mul_1, ir_1 = irreps_in1_list[i] + mul_2, ir_2 = irreps_in2_list[i] + + if mul_1 < mul_2: + irreps_in2_list[i] = (mul_1, ir_2) + irreps_in2_list.insert(i + 1, 
(mul_2 - mul_1, ir_2)) + + if mul_2 < mul_1: + irreps_in1_list[i] = (mul_2, ir_1) + irreps_in1_list.insert(i + 1, (mul_1 - mul_2, ir_1)) + i += 1 + + out = [] + instr = [] + for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1_list, irreps_in2_list)): + for ir in ir_1 * ir_2: + if irreps_out is not None and ir not in irreps_out: + continue + + out.append((mul, ir)) + instr.append((i, i, len(out) - 1, 'uuu', False)) + + return Irreps(irreps_in1_list), Irreps(irreps_in2_list), Irreps(out), instr + + +def _linear_init(irreps_in1, irreps_out): + """Input initial for 'linear' mode.""" + irreps_out = Irreps(irreps_out) + + instr = [] + for i_1, (_, ir_1) in enumerate(irreps_in1.data): + for i_out, (_, ir_out) in enumerate(irreps_out.data): + if ir_1 == ir_out: + instr.append((i_1, 0, i_out, 'uvw', True)) + + return irreps_out, instr + + +def _merge_init(irreps_in1, irreps_in2, irreps_out_filter): + """Input initial for 'merge' mode.""" + irreps_out_filter = Irreps( + irreps_out_filter) if irreps_out_filter is not None else irreps_in1 * irreps_in2 + + irreps_out_list = [] + instr = [] + for i_1, (mul, ir_1) in enumerate(irreps_in1.data): + for i_2, (_, ir_2) in enumerate(irreps_in2.data): + for ir in ir_1 * ir_2: + if ir in irreps_out_filter: + k = len(irreps_out_list) + irreps_out_list.append((mul, ir)) + instr.append((i_1, i_2, k, 'uvu', True)) + + irreps_out = Irreps(irreps_out_list) + irreps_out, p, _ = irreps_out.sort() + + instr = [(i_1, i_2, p[i_out], mode, train) + for i_1, i_2, i_out, mode, train in instr] + + return irreps_out, instr + + +def _raw_ins_check(mir_in1, mir_in2, mir_out, raw_ins): + """Check raw input instructions.""" + if not mir_in1.ir.p * mir_in2.ir.p == mir_out.ir.p: + raise ValueError( + f"The parity of inputs and output do not match. \n \ + {mir_in1.ir.p} * {mir_in2.ir.p} should equal {mir_out.ir.p}.") + if not (abs(mir_in1.ir.l - mir_in2.ir.l) <= mir_out.ir.l and mir_out.ir.l <= mir_in1.ir.l + mir_in2.ir.l): + raise ValueError( + f"The degree of inputs and output do not match. \n \ + The degrees should be |{mir_in1.ir.l} - {mir_in2.ir.l}| <= {mir_out.ir.l} <= |{mir_in1.ir.l} + {mir_in2.ir.l}|.") + if raw_ins[3] not in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']: + raise ValueError( + "The connection mode should be in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'].") + + +def _mode_check(mul_in1, mul_in2, mul_out, ins): + """Consistency check for multiplicities.""" + if ins['mode'] == 'uvw': + if not ins['has_weight']: + raise ValueError("The connection mode 'uvw' should have weights.") + elif ins['mode'] == 'uuu': + if not (mul_in1 == mul_in2 and mul_in2 == mul_out): + raise ValueError( + f"The multiplicity of inputs and output do not match. \ + It should be {mul_in1} == {mul_in2} == {mul_out}.") + elif ins['mode'] == 'uuw': + if not mul_in1 == mul_in2: + raise ValueError( + f"The multiplicity of inputs do not match. \ + It should be {mul_in1} == {mul_in2}.") + if not (ins['has_weight'] or mul_out == 1): + raise ValueError( + f"The multiplicity of input or 'has_weight' do not match. \ + If 'has_weight' == False, {mul_out} should equal 1.") + elif ins['mode'] == 'uvu': + if not mul_in1 == mul_out: + raise ValueError( + f"The multiplicity of input 1 and output do not match. \ + It should be {mul_in1} == {mul_out}.") + elif ins['mode'] == 'uvv': + if not mul_in2 == mul_out: + raise ValueError( + f"The multiplicity of input 2 and output do not match.
\ + It should be {mul_in2} == {mul_out}.") + elif ins['mode'] == 'uvuv': + if not mul_in1 * mul_in2 == mul_out: + raise ValueError( + f"The multiplicity of inputs and output do not match. \ + It should be {mul_in1} * {mul_in2} == {mul_out}.") + + +def _init_einsum(mode, ls): + """tensor graph contractions""" + if mode == 'uuu': + einsum = ops.Einsum("ijk,zui,zuj->zuk") + elif mode == 'uuw': + einsum = ops.Einsum("ijk,zui,zuj->zk") + elif mode == 'uvu': + einsum = ops.Einsum("ijk,zui,zvj->zuk") + elif mode == 'uvv': + einsum = ops.Einsum("ijk,zui,zvj->zvk") + elif mode == 'uvuv': + einsum = ops.Einsum("ijk,zui,zvj->zuvk") + return einsum + + +def _init_einsum_weight(mode, weight_mode, ls): + """tensor graph contractions with weights""" + z = "z" if weight_mode == 'custom' else "" + if mode == 'uvw': + einsum = ops.Einsum(f"ijk,zui,zvj,{z}uvw->zwk") + elif mode == 'uuu': + einsum = ops.Einsum(f"ijk,zui,zuj,{z}u->zuk") + elif mode == 'uuw': + einsum = ops.Einsum(f"ijk,zui,zuj,{z}uw->zwk") + elif mode == 'uvu': + einsum = ops.Einsum(f"ijk,zui,zvj,{z}uv->zuk") + elif mode == 'uvv': + einsum = ops.Einsum(f"ijk,zui,zvj,{z}uv->zvk") + elif mode == 'uvuv': + einsum = ops.Einsum(f"ijk,zui,zvj,{z}uv->zuvk") + return einsum + + +def _init_ncon(mode, ls): + """tensor graph contractions""" + if mode == 'uuu': + con_list = [[1, 2, -3], [-1, -2, 1], [-1, -2, 2]] + elif mode == 'uuw': + con_list = [[1, 2, -2], [-1, 3, 1], [-1, 3, 2]] + elif mode == 'uvu': + con_list = [[1, 2, -3], [-1, -2, 1], [-1, 3, 2]] + elif mode == 'uvv': + con_list = [[1, 2, -3], [-1, 3, 1], [-1, -2, 2]] + elif mode == 'uvuv': + con_list = [[1, 2, -4], [-1, -2, 1], [-1, -3, 2]] + ncon = Ncon(con_list) + return ncon + + +class uvw_ncon_v2(nn.Cell): + """Weighted 'uvw' contraction built from tensordot and a vmapped tensordot.""" + + def __init__(self): + super(uvw_ncon_v2, self).__init__() + self.tensordot1 = tensordot + self.tensordot2 = tensordot + self.tensordot3 = vmap(tensordot, (0, 0, None), 0) + + def construct(self, m1, m2, m3, m4): + # Two tensordots build partial results; the vmapped tensordot then + # contracts them batch-wise. + temp1 = self.tensordot1(m3, m1, [2, 1]) + temp2 = self.tensordot1(m2, m4, [1, 0]) + res = self.tensordot3(temp2, temp1, ([0, 1], [1, 0])) + return res + + +def _init_ncon_weight(mode, weight_mode, ls): + """tensor graph contractions with weights""" + if mode == 'uvw': + con_list = [[1, 2, -3], [-1, 3, 1], [-1, 4, 2], [3, 4, -2]] + elif mode == 'uuu': + con_list = [[1, 2, -3], [-1, -2, 1], [-1, -2, 2], [-2]] + elif mode == 'uuw': + con_list = [[1, 2, -3], [-1, 3, 1], [-1, 3, 2], [3, -2]] + elif mode == 'uvu': + con_list = [[1, 2, -3], [-1, -2, 1], [-1, 3, 2], [-2, 3]] + elif mode == 'uvv': + con_list = [[1, 2, -3], [-1, 3, 1], [-1, -2, 2], [3, -2]] + elif mode == 'uvuv': + con_list = [[1, 2, -4], [-1, -2, 1], [-1, -3, 2], [-2, -3]] + if weight_mode == 'custom': + con_list[3] = [-1] + con_list[3] + ncon = Ncon(con_list) + return ncon + + +def _run_continue(ir1_data, ir2_data, irout_data, ins): + """check trivial computations""" + mir_in1 = ir1_data[ins['indice_one']] + mir_in2 = ir2_data[ins['indice_two']] + mir_out = irout_data[ins['i_out']] + if mir_in1.dim == 0 or mir_in2.dim == 0 or mir_out.dim == 0: + return True + return False + + +class TensorProduct(nn.Cell): + r""" + Versatile tensor product operator of two input `Irreps` and an output `Irreps`, that sends two tensors into a tensor + and keeps the geometric tensor properties. + This class integrates different typical usages: `TensorSquare`, `FullTensorProduct`, `FullyConnectedTensorProduct`, + `ElementwiseTensorProduct` and `Linear`. + + A `TensorProduct` class defines an algebraic structure with equivariance.
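The connection modes set up above are easiest to read as einsum equations. A NumPy sketch of the weighted 'uvw' contraction "ijk,zui,zvj,uvw->zwk" follows, with shapes chosen arbitrarily for illustration:

```python
import numpy as np

z, u, v, w = 4, 2, 3, 5   # batch, mul_in1, mul_in2, mul_out
i, j, k = 3, 3, 5         # irrep dimensions, e.g. 1o, 1o and 2e

w3j = np.random.randn(i, j, k)     # Wigner 3j block for the path
x1 = np.random.randn(z, u, i)      # decomposed first input
x2 = np.random.randn(z, v, j)      # decomposed second input
weight = np.random.randn(u, v, w)  # one scalar weight per (u, v, w) triple

out = np.einsum('ijk,zui,zvj,uvw->zwk', w3j, x1, x2, weight)
print(out.shape)  # (4, 5, 5): every (u, v) pair is mixed into each output w
```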
+ Once the `TensorProduct` object is created and initialized, the algorithm is determined. For any two legal input + tensors, this object will produce an output tensor. + If the object does not have learnable weights, the output tensor is deterministic. + When learnable weights are introduced, this operator corresponds to a general bilinear, equivariant operation, + as a generalization of the standard tensor product. + + If `irreps_in2` is not specified, it will be assigned as `irreps_in1`, corresponding to `TensorSquare`. + If `irreps_out` is not specified, this operator will account for all possible output irreps. + If both `irreps_out` and `instructions` are not specified, this operator is the standard tensor product without + any learnable weights, corresponding to ``FullTensorProduct``. + + Each output irrep should satisfy: + + .. math:: + |l_1 - l_2| \leq l_{out} \leq l_1 + l_2, \qquad p_1 p_2 = p_{out} + + Args: + irreps_in1 (Union[str, Irrep, Irreps]): Irreps for the first input. + irreps_in2 (Union[str, Irrep, Irreps, None]): Irreps for the second input. Default: ``None``. + If `irreps_in2` is None, `irreps_in2` will be assigned as '0e' for 'linear' instructions, or as `irreps_in1` otherwise, corresponding to `TensorSquare`. + irreps_out (Union[str, Irrep, Irreps, None]): Irreps for the output in 'connect' and custom instructions, or filter irreps for the output otherwise. + If `irreps_out` is None, `irreps_out` will be the full tensor product irreps (including all possible paths). Default: ``None``. + instructions (Union[str, List[Tuple[int, int, int, str, bool, (float)]]]): List of tensor product path instructions. Default: ``'full'``. + For `str` in {'full', 'connect', 'element', 'linear', 'merge'}, the instructions are constructed automatically according to the different modes: + + - 'full': each output irrep for every pair of input irreps is created and returned independently. The outputs are not mixed with each other. + Corresponding to the standard tensor product `FullTensorProduct` if `irreps_out` is not specified. + - 'connect': each output is a learned weighted sum of compatible paths. This allows the operator to produce outputs with any multiplicity. + Corresponding to `FullyConnectedTensorProduct`. + - 'element': the irreps are multiplied one-by-one. The inputs will be split such that the multiplicities of the outputs match the multiplicities of the inputs. + Corresponding to `ElementwiseTensorProduct`. + - 'linear': equivariant linear operation on the first irreps, while the second irreps is set to be '0e'. This can be regarded as the geometric-tensor version of the dense layer. + Corresponding to `Linear`. + - 'merge': Automatically build 'uvu' mode instructions with trainable parameters. The `irreps_out` here plays the role of output filters. + + For `List[Tuple[int, int, int, str, bool, (float)]]`, the instructions are constructed manually. + + Each instruction contains a tuple: (indice_one, indice_two, i_out, mode, has_weight, (optional: path_weight)). + Each instruction puts ``in1[indice_one]`` :math:`\otimes` ``in2[indice_two]`` into ``out[i_out]``. + + - `indice_one`, `indice_two`, `i_out`: int, the index of the irrep in irreps for `irreps_in1`, `irreps_in2` and `irreps_out` correspondingly. + - `mode`: str in {'uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'}, the way the multiplicities of each path are treated. 'uvw' is the fully mixed mode.
+ - `has_weight`: bool, `True` if this path should have learnable weights, otherwise `False`. + - `path_weight`: float, a multiplicative weight to apply to the output of this path. Default: 1.0. + + irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations. Default: ``'component'``. + + - 'norm': :math:`\| x \| = \| y \| = 1 \Longrightarrow \| x \otimes y \| = 1` + + path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: ``'element'``. + + - 'element': each output is normalized by the total number of elements (independently of their paths). + - 'path': each path is normalized by the total number of elements in the path, then each output is normalized by the number of paths. + + weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', 'xavier_uniform'}, the initial method of weights. Default: ``'normal'``. + weight_mode (str): {'inner', 'share', 'custom'}, determines the weights' mode. Default: ``'inner'``. + + - 'inner': weights will be initialized internally by the tensor product. + - 'share': weights should be given manually, without a batch dimension. + - 'custom': weights should be given manually, with a batch dimension. + + core_mode (str): {'ncon', 'einsum'}, determines the core computation mode. Default: ``'ncon'``. + dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` . + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` . + + Inputs: + - **x** (Tensor) - The shape of Tensor is ``(..., irreps_in1.dim)`` + - **y** (Tensor) - The shape of Tensor is ``(..., irreps_in2.dim)`` + - **weight** (Tensor) - `Tensor` or list of `Tensor`, optional; + required if `weight_mode` is ``'share'`` or ``'custom'``. + The shape of Tensor is ``(self.weight_numel,)`` if `weight_mode` is ``'share'``. + The shape of Tensor is ``(..., self.weight_numel)`` if `weight_mode` is ``'custom'``, + or list of tensors of shapes ``weight_shape`` / ``(...) + weight_shape``. + Use ``self.instructions`` to know what the weights are used for. + + Outputs: + - **outputs** (Tensor) - The shape of Tensor is ``(..., irreps_out.dim)``. + + Raises: + ValueError: If `irreps_out` is not legal. + ValueError: If the connection mode is not in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']. + ValueError: If the degree of inputs and output do not match. + ValueError: If the parity of inputs and output do not match. + ValueError: If the multiplicity of inputs and output do not match. + ValueError: If the connection mode is 'uvw', but `has_weight` is `False`. + ValueError: If the connection mode is 'uuw' and `has_weight` is `False`, but the multiplicity is not equal to 1. + ValueError: If the initial method is not supported. + ValueError: If the number of input tensors does not match the number of input irreps.
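Two small sketches tying the argument descriptions above together (both standalone illustrations; `path_shape` below mirrors the table used later in `_ins_init`):

```python
import numpy as np

def path_shape(mode, mul_1, mul_2, mul_out):
    """Weight-block shape implied by each connection mode."""
    return {
        'uvw': (mul_1, mul_2, mul_out),  # fully mixed: one weight per triple
        'uvu': (mul_1, mul_2),
        'uvv': (mul_1, mul_2),
        'uuw': (mul_1, mul_out),
        'uuu': (mul_1,),                 # fully diagonal
        'uvuv': (mul_1, mul_2),
    }[mode]

print(path_shape('uvw', 3, 2, 5))  # (3, 2, 5)

# Weight shapes implied by `weight_mode`, for a hypothetical product with
# weight_numel == 12 and a batch of 8:
shared_weight = np.zeros(12)       # 'share': ndim == 1, reused across the batch
custom_weight = np.zeros((8, 12))  # 'custom': ndim > 1, one row per sample
print(shared_weight.shape, custom_weight.shape)
```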
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindchemistry.e3.o3 import TensorProduct + Standard tensor product: + >>> tp1 = TensorProduct('2x1o+4x0o', '1x1o+3x0e') + TensorProduct [full] (2x1o+4x0o x 1x1o+3x0e -> 2x0e+12x0o+6x1o+2x1e+4x1e+2x2e) + >>> v1 = ms.Tensor(np.linspace(1., 2., tp1.irreps_in1.dim), dtype=ms.float32) + >>> v2 = ms.Tensor(np.linspace(2., 3., tp1.irreps_in2.dim), dtype=ms.float32) + >>> tp1(v1, v2).shape + (1, 60) + Elementwise tensor product: + >>> tp2 = TensorProduct('2x2e+4x1o', '3x1e+3x0o') + TensorProduct [element] (2x2e+1x1o+3x1o x 2x1e+1x1e+3x0o -> 2x1e+2x2e+2x3e+1x0o+1x1o+1x2o+3x1e) + >>> tp2.instructions + [(0, 0, 0, 'uuu', False), (0, 0, 1, 'uuu', False), (0, 0, 2, 'uuu', False), (1, 1, 3, 'uuu', False), + (1, 1, 4, 'uuu', False), (1, 1, 5, 'uuu', False), (2, 2, 6, 'uuu', False)] + Custom tensor product with learnable weights: + >>> tp3 = TensorProduct( + ... '3x2o+2x1o', '2x2e+4x1o+5x0e', '2x3o+8x1e+10x1o', + ... [ + ... (0,0,0,'uvv',True), + ... (1,0,0,'uuu',True), + ... (1,1,1,'uvuv',True), + ... (1,2,2,'uvw',True) + ... ] + ... ) + TensorProduct [custom] (3x2o+2x1o x 2x2e+4x1o+5x0e -> 2x3o+8x1e+10x1o) + >>> [w.shape for w in tp3.weights] + [(3, 2), (2,), (2, 4), (2, 5, 10)] + Fully connected tensor product with an output filter: + >>> tp4 = TensorProduct('2x1o', irreps_out='5x2e+4x1e+7x1o', instructions='connect') + TensorProduct [connect] (2x1o x 2x1o -> 5x2e+4x1e) + >>> v1 = ms.Tensor(np.linspace(1., 2., tp4.irreps_in1.dim), dtype=ms.float32) + >>> tp4(v1).shape + (1, 37) + """ + __slots__ = ('irreps_in1', 'irreps_in2', 'irreps_out', + 'weights', '_in2_is_none', '_mode', '_device', 'output_mask', 'core_mode') + + def __init__( + self, + irreps_in1, + irreps_in2=None, + irreps_out=None, + instructions='full', + dtype=float32, + irrep_norm='component', + path_norm='element', + weight_init='normal', + weight_mode='inner', + core_mode='ncon', + ncon_dtype=float32 + ): + super().__init__() + + if weight_mode not in ['inner', 'share', 'custom']: + raise ValueError( + "`weight_mode` should be one of ['inner', 'share', 'custom'].") + if core_mode not in ['ncon', 'einsum']: + raise ValueError( + "`core_mode` should be one of ['ncon', 'einsum'].") + elif core_mode == 'einsum' and get_context('device_target') != 'GPU': + raise ValueError( + f"The `core_mode` 'einsum' only supports GPU, but got {get_context('device_target')}.") + self.weight_mode = weight_mode + self.dtype = dtype + self.core_mode = core_mode + self.ones = ops.Ones() + self.zeros = ops.Zeros() + + self.irreps_in1 = Irreps(irreps_in1).simplify() + if irreps_in2 is None: + self.irreps_in2 = Irreps(irreps_in1).simplify() + self._in2_is_none = True + else: + self.irreps_in2 = Irreps(irreps_in2).simplify() + self._in2_is_none = False + + self.irreps_out, instructions = self._input_init( + self.irreps_in1, self.irreps_in2, irreps_out, instructions) + + self.instr, self._ncons = self._ins_init(instructions) + + self.weight_numel = sum(_prod(ins['path_shape']) + for ins in self.instr if ins['has_weight']) + + self.weights = self._weight_init(weight_init) + + self.output_mask = self._init_mask() + + self._normalization(irrep_norm=irrep_norm, path_norm=path_norm) + + self.ncon_dtype = ncon_dtype + + def construct(self, v1, v2=None, weight=None): + """Implement tensor product for input tensors.""" + self._weight_check(weight) + + if self._in2_is_none: + if v2 is not None: + raise ValueError("This tensor product should input 1 tensor.") + + if self._mode ==
'linear': + v2_shape = v1.shape[:-1] + (1,) + v2 = self.ones(v2_shape, v1.dtype) + else: + v2 = v1.copy() + else: + if v2 is None: + raise ValueError( + f"This tensor product should input 2 tensors.") + if self._mode == 'linear': + v2_shape = v1.shape[:-1] + (1,) + v2 = self.ones(v2_shape, v1.dtype) + + batch_shape = v1.shape[:-1] + v1s = self.irreps_in1.decompose(v1, batch=True) + v2s = self.irreps_in2.decompose(v2, batch=True) + weight = self._get_weights(weight) + if not (v1.shape[-1] == self.irreps_in1.dim and v2.shape[-1] == self.irreps_in2.dim): + raise ValueError(f"The shape of input tensors do not match.") + + v3_list = [] + weight_ind = 0 + fn = 0 + + for ins in self.instr: + if _run_continue(self.irreps_in1.data, self.irreps_in2.data, self.irreps_out.data, ins): + continue + fn = self._ncons[ins['i_ncon']] + if ins['has_weight']: + l = _prod(ins['path_shape']) + w = narrow(weight, -1, weight_ind, l).reshape(((-1,) + if self.weight_mode == 'custom' else ()) + ins['path_shape']).astype(self.ncon_dtype) + weight_ind += l + if self.core_mode == 'einsum': + v3 = fn((ins['wigner_matrix'].astype(self.ncon_dtype), v1s[ins['indice_one']].astype(self.ncon_dtype), v2s[ins['indice_two']].astype(self.ncon_dtype), w)) + else: + v3 = fn([ins['wigner_matrix'].astype(self.ncon_dtype), v1s[ins['indice_one']].astype(self.ncon_dtype), v2s[ins['indice_two']].astype(self.ncon_dtype), w]) + else: + if self.core_mode == 'einsum': + v3 = fn((ins['wigner_matrix'].astype(self.ncon_dtype), v1s[ins['indice_one']].astype(self.ncon_dtype), v2s[ins['indice_two']].astype(self.ncon_dtype))) + else: + v3 = fn([ins['wigner_matrix'].astype(self.ncon_dtype), v1s[ins['indice_one']].astype(self.ncon_dtype), v2s[ins['indice_two']].astype(self.ncon_dtype)]) + v3_list.append(ins['path_weight'].astype(self.dtype) * v3.astype(self.dtype)) + + v_out = _compose(v3_list, self.irreps_out.data, self.instr, batch_shape) + return v_out + + def __repr__(self): + return f'TensorProduct [{self._mode}] ({self.irreps_in1.simplify().__repr__()} x {self.irreps_in2.simplify().__repr__()} -> {self.irreps_out.simplify().__repr__()} | {self.weight_numel} weights)' + + @property + def instructions(self): + return [tuple(ins.values())[:5] for ins in self.instr] + + def _input_init(self, irreps_in1, irreps_in2, irreps_out, instructions): + if not isinstance(instructions, str): + irreps_out = irreps_in1 * \ + irreps_in2 if irreps_out is None else Irreps(irreps_out) + self._mode = 'custom' + else: + if instructions == 'connect': + irreps_out, instructions = _connect_init( + irreps_in1, irreps_in2, irreps_out) + self._mode = 'connect' + + elif instructions == 'full': + irreps_out, instructions = _full_init( + irreps_in1, irreps_in2, irreps_out) + self._mode = 'full' + + elif instructions == 'element': + self.irreps_in1, self.irreps_in2, irreps_out, instructions = _element_init( + irreps_in1, irreps_in2, irreps_out) + self._mode = 'element' + + elif instructions == 'linear': + self.irreps_in2 = Irreps('0e') + irreps_out, instructions = _linear_init(irreps_in1, irreps_out) + self._mode = 'linear' + + elif instructions == 'merge': + irreps_out, instructions = _merge_init( + irreps_in1, irreps_in2, irreps_out) + self._mode = 'merge' + + else: + raise ValueError( + f"Unexpected instructions mode {instructions}") + + return irreps_out, instructions + + def _ins_init(self, raw_ins): + """reform instructions""" + raw_ins = [x if len(x) == 6 else x + (1.0,) for x in raw_ins] + res = [] + ncons = [] + + for ins in raw_ins: + indice_one = ins[0] + indice_two 
= ins[1] + i_out = ins[2] + mode = ins[3] + has_weight = ins[4] + path_weight = ins[5] + + mirs = ( + self.irreps_in1.data[indice_one], self.irreps_in2.data[indice_two], self.irreps_out.data[i_out]) + muls = (mirs[0].mul, mirs[1].mul, mirs[2].mul) + + _raw_ins_check(*mirs, ins) + + path_shape = { + 'uvw': (muls[0], muls[1], muls[2]), + 'uvu': (muls[0], muls[1]), + 'uvv': (muls[0], muls[1]), + 'uuw': (muls[0], muls[2]), + 'uuu': (muls[0],), + 'uvuv': (muls[0], muls[1]), + }[mode] + + num_elements = { + 'uvw': (muls[0] * muls[1]), + 'uvu': muls[1], + 'uvv': muls[0], + 'uuw': muls[0], + 'uuu': 1, + 'uvuv': 1, + }[mode] + + ls = (mirs[0].ir.l, mirs[1].ir.l, mirs[2].ir.l) + + d, op = self._ins_dict(indice_one, indice_two, i_out, mode, has_weight, + path_weight, path_shape, num_elements, wigner_3j(*ls, self.dtype), ls) + ncons.append(op) + d['i_ncon'] = len(ncons) - 1 + res.append(d) + + _mode_check(*muls, res[-1]) + + return res, ncons + + def _ins_dict(self, *args): + """generate reformed instructions""" + d = {} + keys = ['indice_one', 'indice_two', 'i_out', 'mode', 'has_weight', + 'path_weight', 'path_shape', 'num_elements', 'wigner_matrix', 'ls'] + for i, arg in enumerate(args): + d[keys[i]] = arg + + if d['has_weight']: + if self.core_mode == 'einsum': + operator = _init_einsum_weight( + d['mode'], self.weight_mode, d['ls']) + else: + operator = _init_ncon_weight( + d['mode'], self.weight_mode, d['ls']) + else: + if self.core_mode == 'einsum': + operator = _init_einsum(d['mode'], d['ls']) + else: + operator = _init_ncon(d['mode'], d['ls']) + + return d, operator + + def _weight_init(self, init_method): + """init weights""" + init_method = renormal_initializer(init_method) + + if self.weight_numel > 0 and self.weight_mode == 'inner': + weights = Parameter(initializer(init_method, (1, self.weight_numel), dtype=self.dtype).init_data().flatten()) + else: + weights = None + + return weights + + def _init_mask(self): + if self.irreps_out.dim > 0: + output_mask = ops.cat([ + self.ones(mul * ir.dim, int32) + if any( + (ins['i_out'] == i_out) and (ins['path_weight'] + != 0) and (0 not in ins['path_shape']) + for ins in self.instr + ) + else self.zeros(mul * ir.dim, int32) + for i_out, (mul, ir) in enumerate(self.irreps_out.data) + ]) + else: + output_mask = Tensor(0) + + return output_mask + + def _normalization(self, irrep_norm, path_norm): + """path normalization""" + for ins in self.instr: + mir_in1 = self.irreps_in1.data[ins['indice_one']] + mir_in2 = self.irreps_in2.data[ins['indice_two']] + mir_out = self.irreps_out.data[ins['i_out']] + + alpha = 1. + if irrep_norm == 'component': + alpha = mir_out.ir.dim + if irrep_norm == 'norm': + alpha = mir_in1.ir.dim * mir_in2.ir.dim + + x = 1. 
+ if path_norm == 'element': + x = sum(i['num_elements'] + for i in self.instr if i['i_out'] == ins['i_out']) + if path_norm == 'path': + x = ins['num_elements'] + x *= len([i for i in self.instr if i['i_out'] + == ins['i_out']]) + + if x > 0.0: + alpha /= x + + alpha *= ins['path_weight'] + ins['path_weight'] = _sqrt(alpha, self.dtype) + + def _weight_check(self, weight): + if self.weight_mode == 'inner': + if weight is None: + return True + raise ValueError( + f"For `weight_mode` {self.weight_mode}, the `weight` should not be given manually.") + elif self.weight_mode == 'share': + if weight is None: + raise ValueError( + f"For `weight_mode` {self.weight_mode}, the `weight` should be given manually.") + if weight.ndim != 1: + raise ValueError( + f"The shape of custom weight {weight.shape} is illegal.") + elif self.weight_mode == 'custom': + if weight is None: + raise ValueError( + f"For `weight_mode` {self.weight_mode}, the `weight` should be given manually.") + if weight.ndim <= 1: + raise ValueError( + f"Custom weight {weight} should have a batch dimension if `weight_mode` is `'custom'`.") + else: + raise ValueError(f"Unknown `weight_mode`: {self.weight_mode}.") + return True + + def _get_weights(self, weight): + if weight is None: + return self.weights + else: + return weight.reshape(-1, self.weight_numel) diff --git a/mindscience/e3nn/o3/wigner.py b/mindscience/e3nn/o3/wigner.py new file mode 100644 index 0000000000000000000000000000000000000000..bd086be33a64b84ee9bc8adb500f69f7f830eaab --- /dev/null +++ b/mindscience/e3nn/o3/wigner.py @@ -0,0 +1,336 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""wigner""" +import functools +import math +from fractions import Fraction +from math import factorial + +import numpy as np + +from mindspore import Tensor, ops, float32, float64, complex64, complex128 + +from ..utils.func import _ndexpm, broadcast_args, _expand_last_dims + +PI = Tensor(math.pi) + + +def change_basis_real_to_complex(l, dtype=float32): + r""" + Convert a real basis of spherical harmonics into a complex basis. + + Args: + l (int): degree of spherical harmonics. + dtype (dtype): {float32, float64} data type of the real basis. Default: float32. + + Returns: + Tensor, the complex basis with dtype complex64 for `dtype` = float32 and complex128 for `dtype` = float64. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import change_basis_real_to_complex + >>> m = change_basis_real_to_complex(1) + >>> print(m) + [[-0.70710677+0.j 0. +0.j 0. -0.70710677j] + [ 0. +0.j 0. -1.j 0. +0.j ] + [-0.70710677+0.j 0. +0.j 0.
+0.70710677j]] + """ + q = np.zeros((2 * l + 1, 2 * l + 1), np.complex128) + for m in range(-l, 0): + q[l + m, l + abs(m)] = 1 / 2 ** 0.5 + q[l + m, l - abs(m)] = -1j / 2 ** 0.5 + q[l, l] = 1 + for m in range(1, l + 1): + q[l + m, l + abs(m)] = (-1) ** m / 2 ** 0.5 + q[l + m, l - abs(m)] = 1j * (-1) ** m / 2 ** 0.5 + q = (-1j) ** l * q + + dtype = { + float32: complex64, + float64: complex128, + }[dtype] + + q_new = Tensor(q, dtype=dtype) + return q_new + + +def su2_generators(j, dtype=complex64): + r""" + Compute the su(2) Lie algebra generators. + + Args: + j (int): degree of generators. + dtype (dtype): {complex64, complex128} data type of generators. Default: complex64. + + Returns: + Tensor, su(2) generators with dtype `dtype`. + + Raises: + TypeError: If `j` is not int. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import su2_generators + >>> m = su2_generators(1) + >>> print(m) + [[[ 0. +0.j 0.70710677+0.j + 0. +0.j ] + [-0.70710677+0.j 0. +0.j + 0.70710677+0.j ] + [ 0. +0.j -0.70710677+0.j + 0. +0.j ]] + [[-0. -1.j 0. +0.j + 0. +0.j ] + [ 0. +0.j 0. +0.j + 0. +0.j ] + [ 0. +0.j 0. +0.j + 0. +1.j ]] + [[ 0. -0.j 0. +0.70710677j + 0. -0.j ] + [ 0. +0.70710677j 0. -0.j + 0. +0.70710677j] + [ 0. -0.j 0. +0.70710677j + 0. -0.j ]]] + """ + if not isinstance(j, int): + raise TypeError("The degree `j` should be an int.") + m = np.arange(-j, j) + raising = np.diag(-np.sqrt(j * (j + 1) - m * (m + 1)), k=-1) + + m = np.arange(-j + 1, j + 1) + lowering = np.diag(np.sqrt(j * (j + 1) - m * (m - 1)), k=1) + + m = np.arange(-j, j + 1) + res = np.stack([ + 0.5 * (raising + lowering), # x (usually) + np.diag(1j * m), # z (usually) + -0.5j * (raising - lowering), # -y (usually) + ], axis=0) + return Tensor(res, dtype=dtype) + + +def so3_generators(l, dtype=float32): + r""" + Compute the so(3) Lie algebra generators. + + Args: + l (int): degree of generators. + dtype (dtype): {float32, float64} data type of generators. Default: float32. + + Returns: + Tensor, so(3) generators with dtype `dtype`. + + Raises: + TypeError: If `l` is not int. + ValueError: If matrices data are inconsistent. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import so3_generators + >>> m = so3_generators(1) + >>> print(m) + [[[ 0. 0. 0. ] + [ 0. 0. -0.99999994] + [ 0. 0.99999994 0. ]] + [[ 0. 0. 0.99999994] + [ 0. 0. 0. ] + [-0.99999994 0. 0. ]] + [[ 0. -0.99999994 0. ] + [ 0.99999994 0. 0. ] + [ 0. 0. 0. ]]] + """ + if not isinstance(l, int): + raise TypeError("The degree `l` should be an int.") + cdtype = { + float32: complex64, + float64: complex128, + }[dtype] + X = su2_generators(l, dtype=cdtype).asnumpy() + Q = change_basis_real_to_complex(l, dtype=dtype).asnumpy() + X = np.conj(Q.T) @ X @ Q + + if not np.all(np.abs(np.imag(X)) < 1e-5): + raise ValueError("The so(3) generators should be real, but the imaginary parts are not negligible.") + X_real = np.real(X) + return Tensor(X_real, dtype=dtype) + + +def wigner_D(l, alpha, beta, gamma): + r""" + Wigner D matrix representation of SO(3). + + It satisfies the following properties: + * :math:`D(\text{identity rotation}) = \text{identity matrix}` + * :math:`D(R_1 \circ R_2) = D(R_1) \circ D(R_2)` + * :math:`D(R^{-1}) = D(R)^{-1} = D(R)^T` + + Args: + l (int): degree of representation. + alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third. + beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second.
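A quick NumPy check that the l = 1 matrices printed in the `so3_generators` docstring close under the so(3) commutation relation [X_0, X_1] = X_2 (values rounded to exact integers for the check):

```python
import numpy as np

X = np.array([
    [[0., 0., 0.], [0., 0., -1.], [0., 1., 0.]],
    [[0., 0., 1.], [0., 0., 0.], [-1., 0., 0.]],
    [[0., -1., 0.], [1., 0., 0.], [0., 0., 0.]],
])
comm = X[0] @ X[1] - X[1] @ X[0]
print(np.allclose(comm, X[2]))  # True: the matrices represent so(3)
```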
+ gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first. + + Returns: + Tensor, Wigner D matrix :math:`D^l(\alpha, \beta, \gamma)`. The shape of Tensor is :math:`(2l+1, 2l+1)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import wigner_D + >>> m = wigner_D(1,1,1,1) + >>> print(m) + [[-0.09064701 0.7080733 0.70029646] + [ 0.7080733 0.54030234 -0.45464867] + [-0.7002964 0.45464864 -0.5503447 ]] + + """ + + alpha, beta, gamma = broadcast_args(alpha, beta, gamma) + alpha = _expand_last_dims(alpha) % (2 * PI) + beta = _expand_last_dims(beta) % (2 * PI) + gamma = _expand_last_dims(gamma) % (2 * PI) + X = so3_generators(l) + return ops.matmul(ops.matmul(_ndexpm(alpha * X[1]), _ndexpm(beta * X[0])), _ndexpm(gamma * X[1])) + + +def wigner_3j(l1, l2, l3, dtype=float32): + r""" + Wigner 3j symbols :math:`C_{lmn}`. + + They satisfy the following two properties: + + .. math:: + C_{lmn} = C_{ijk} D_{il}(g) D_{jm}(g) D_{kn}(g) \qquad \forall g \in SO(3) + + where :math:`D` are given by `wigner_D`. + + .. math:: + C_{ijk} C_{ijk} = 1 + + Args: + l1 (int): :math:`l_1` parameter of ``wigner_3j``. + l2 (int): :math:`l_2` parameter of ``wigner_3j``. + l3 (int): :math:`l_3` parameter of ``wigner_3j``. + dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` . + + Returns: + Tensor, Wigner 3j symbols :math:`C_{lmn}`. The shape of Tensor is :math:`(2l_1+1, 2l_2+1, 2l_3+1)`. + + Raises: + TypeError: If `l1`, `l2` or `l3` are not int. + ValueError: If `l1`, `l2` and `l3` do not satisfy abs(l2 - l3) <= l1 <= l2 + l3. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import wigner_3j + >>> m = wigner_3j(1,1,1) + >>> print(m) + [[[ 0. 0. 0. ] + [ 0. 0. 0.4082483] + [ 0. -0.4082483 0. ]] + [[ 0. 0. -0.4082483] + [ 0. 0. 0. ] + [ 0.4082483 0. 0. ]] + [[ 0. 0.4082483 0. ] + [-0.4082483 0. 0. ] + [ 0. 0. 0. ]]] + """ + if not (isinstance(l1, int) and isinstance(l2, int) and isinstance(l3, int)): + raise TypeError("The degrees `l1`, `l2` and `l3` should all be int.") + if not abs(l2 - l3) <= l1 <= l2 + l3: + raise ValueError( + f"The input degrees \"{l1}\" and \"{l2}\" do not match the output degree \"{l3}\".
\nThe degrees should be |{l1} - {l2}| <= {l3} <= |{l1} + {l2}|.") + C = _so3_clebsch_gordan(l1, l2, l3) + + return Tensor(C, dtype=dtype) + + +@functools.lru_cache(maxsize=None) +def _so3_clebsch_gordan(l1, l2, l3, dtype=float64): + """Calculates the Clebsch-Gordon matrix for SO(3) coupling l1 and l2 to give l3.""" + Q1 = change_basis_real_to_complex(l1, dtype=dtype).asnumpy() + Q2 = change_basis_real_to_complex(l2, dtype=dtype).asnumpy() + Q3 = change_basis_real_to_complex(l3, dtype=dtype).asnumpy() + C = _su2_clebsch_gordan(l1, l2, l3) + + C = np.einsum('ij,kl,mn,ikn->jlm', Q1, Q2, np.conj(Q3.T), C) + + if not np.all(np.abs(np.imag(C)) < 1e-5): + raise ValueError + C = np.real(C) + + C = C / np.linalg.norm(C) + return C + + +@functools.lru_cache(maxsize=None) +def _su2_clebsch_gordan(j1, j2, j3): + """Calculates the Clebsch-Gordon matrix for SU(2) coupling j1 and j2 to give j3.""" + if not (isinstance(j1, (int, float)) and isinstance(j2, (int, float)) and isinstance(j3, (int, float))): + raise TypeError + mat = np.zeros((int(2 * j1 + 1), int(2 * j2 + 1), + int(2 * j3 + 1)), np.float64) + if int(2 * j3) in range(int(2 * abs(j1 - j2)), int(2 * (j1 + j2)) + 1, 2): + for m1 in (x / 2 for x in range(-int(2 * j1), int(2 * j1) + 1, 2)): + for m2 in (x / 2 for x in range(-int(2 * j2), int(2 * j2) + 1, 2)): + if abs(m1 + m2) <= j3: + mat[int(j1 + m1), int(j2 + m2), int(j3 + m1 + m2) + ] = _su2_clebsch_gordan_coeff((j1, m1), (j2, m2), (j3, m1 + m2)) + + return mat + + +def _su2_clebsch_gordan_coeff(idx1, idx2, idx3): + """core function of the Clebsch-Gordon coefficient for SU(2) coupling (j1,m1) and (j2,m2) to give (j3,m3).""" + + j1, m1 = idx1 + j2, m2 = idx2 + j3, m3 = idx3 + + if m3 != m1 + m2: + return 0 + vmin = int(max([-j1 + j2 + m3, -j1 + m1, 0])) + vmax = int(min([j2 + j3 + m1, j3 - j1 + j2, j3 + m3])) + + def f(n): + if not n == round(n): + raise ValueError + return factorial(round(n)) + + C = ( + (2.0 * j3 + 1.0) * Fraction( + f(j3 + j1 - j2) * f(j3 - j1 + j2) * + f(j1 + j2 - j3) * f(j3 + m3) * f(j3 - m3), + f(j1 + j2 + j3 + 1) * f(j1 - m1) * + f(j1 + m1) * f(j2 - m2) * f(j2 + m2) + ) + ) ** 0.5 + + S = 0 + for v in range(vmin, vmax + 1): + S += (-1) ** int(v + j2 + m2) * Fraction( + f(j2 + j3 + m1 - v) * f(j1 - m1 + v), + f(v) * f(j3 - j1 + j2 - v) * f(j3 + m3 - v) * f(v + j1 - j2 - m3) + ) + C = C * S + return C diff --git a/tests/e3nn/__init__.py b/mindscience/e3nn/utils/__init__.py similarity index 70% rename from tests/e3nn/__init__.py rename to mindscience/e3nn/utils/__init__.py index 83b15297dbb5877a8f4175bfffa35a331ca2a156..2161cda3e7f0cbf4c661bca3b0440693d8329b63 100644 --- a/tests/e3nn/__init__.py +++ b/mindscience/e3nn/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
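The `wigner_3j(1, 1, 1)` output printed above is the Levi-Civita tensor scaled to unit Frobenius norm (0.4082483 is 1/sqrt(6)), which makes the stated property C_ijk C_ijk = 1 easy to verify in NumPy:

```python
import numpy as np

eps = np.zeros((3, 3, 3))
for a, b, c, s in [(0, 1, 2, 1), (1, 2, 0, 1), (2, 0, 1, 1),
                   (0, 2, 1, -1), (2, 1, 0, -1), (1, 0, 2, -1)]:
    eps[a, b, c] = s
C = eps / np.sqrt(6.0)
print(np.sum(C * C))  # 1.0
print(C[0, 1, 2])     # 0.4082482..., matching the docstring output
```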
# ============================================================================ +"""init""" +from .ncon import Ncon +from .radius import radius, radius_graph, radius_full, radius_graph_full + + +__all__ = [ + "Ncon", + "radius", + "radius_graph", + "radius_full", + "radius_graph_full", +] diff --git a/mindscience/e3nn/utils/batch_dot.py b/mindscience/e3nn/utils/batch_dot.py new file mode 100644 index 0000000000000000000000000000000000000000..06dfbd2cee8560828bb76714eb1d6c556336dd64 --- /dev/null +++ b/mindscience/e3nn/utils/batch_dot.py @@ -0,0 +1,165 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Batch dot product operations for tensor computations. + +This module provides utilities for performing batch-wise dot products +between tensors with support for various axis configurations. +""" +from mindspore.ops.primitive import constexpr +from mindspore.ops import functional as F +from mindspore.ops import operations as P + + +@constexpr +def _get_batch_size(x1_shape, x2_shape): + """ + Get batch sizes from two inputs + """ + return x1_shape[0], x2_shape[0] + + +@constexpr +def _calc_new_shape_batchdot(shape, axes, position=0): + """ + Calculate transpose and reshape parameters for input transformations, + 'position' refers to whether tensor is first or second in the op. 
+ """ + axis = axes[position] + contraction_axes = tuple([axis]) + prod_contraction = 1 + for i in contraction_axes: + prod_contraction *= shape[i] + free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes) + free_dims = tuple(shape[i] for i in free_axes) + prod_free = 1 + for free_dim in free_dims: + prod_free *= free_dim + + transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes + transpose_perm = tuple([0]) + transpose_perm + new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction) + new_shape = tuple([shape[0]]) + new_shape + return new_shape, transpose_perm, free_dims + + +@constexpr +def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None): + """ + Check whether batch size of two inputs are the same + """ + msg_prefix = f"For '{prim_name}', the" if prim_name else "The" + if x1_batch_size != x2_batch_size: + raise ValueError(f"{msg_prefix} inputs 'x1', 'x2' should have the same batch sizes, but got " + f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.") + + +@constexpr +def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): + """ + Check whether axes are valid and cast axes from tuple to list + """ + if axes is None: + if len(x2_shape) == 2: + axes = [len(x1_shape) - 1, len(x2_shape) - 1] + else: + axes = [len(x1_shape) - 1, len(x2_shape) - 2] + + if isinstance(axes, (list, tuple)): + if isinstance(axes, tuple): + axes = list(axes) + # Reverse if axis < 0 + if axes[0] < 0: + axes[0] += len(x1_shape) + if axes[1] < 0: + axes[1] += len(x2_shape) + elif isinstance(axes, int): + if axes < 0: + axes = [axes + len(x1_shape), axes + len(x2_shape)] + else: + axes = [axes, axes] + return axes + + +@constexpr +def _get_output_shape(batch_size, x1_ret, x2_ret): + """ + Compute output shape for batch dot + """ + output_shape = tuple([batch_size]) + x1_ret + x2_ret + return output_shape + + +def batch_dot(x1, x2, axes=None): + """ + Compute batch-wise dot product of two tensors. + + Args: + x1 (Tensor): First input tensor with shape (batch_size, ...). + x2 (Tensor): Second input tensor with shape (batch_size, ...). + axes (int, list, tuple, optional): Axes to perform dot product along. + If None, defaults to the last axis of x1 and second-to-last axis of x2. + + Returns: + Tensor: The batch dot product result. + + Raises: + ValueError: If batch sizes of x1 and x2 don't match. 
+ """ + transpose_op = P.Transpose() + batch_matmul_op = P.BatchMatMul() + squeeze_one_op = P.Squeeze(1) + squeeze_minus_one_op = P.Squeeze(-1) + # input validity checks + x1_shape = F.shape(x1) + x2_shape = F.shape(x2) + x1_dim_num = len(x1_shape) + x2_dim_num = len(x2_shape) + + x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape) + + _check_batch_size(x1_batch_size, x2_batch_size, 'batch_dot') + axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes) + + if x1_dim_num == 2: + x1 = F.expand_dims(x1, 1) + axes[0] += 1 + if x2_dim_num == 2: + x2 = F.expand_dims(x2, 2) + + x1_shape = F.shape(x1) + x2_shape = F.shape(x2) + + x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0) + x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1) + output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret) + + x1_transposed = transpose_op(x1, x1_transpose_fwd) + x2_transposed = transpose_op(x2, x2_transpose_fwd) + x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd) + x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd) + + # Batch matmal op part + mul_result = batch_matmul_op(x1_reshaped, x2_reshaped) + + final_result = F.reshape(mul_result, output_shape) + + # if the original dims are expanded, restore them from 3 to 2 + if x1_dim_num == 2: + final_result = squeeze_one_op(final_result) + elif x2_dim_num == 2: + final_result = squeeze_minus_one_op(final_result) + + return final_result diff --git a/mindscience/e3nn/utils/func.py b/mindscience/e3nn/utils/func.py new file mode 100644 index 0000000000000000000000000000000000000000..22bdf38ad6ef86c03e2489322d12264d6c3aab6b --- /dev/null +++ b/mindscience/e3nn/utils/func.py @@ -0,0 +1,166 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Utility functions for tensor operations and broadcasting. + +This module provides various utility functions for tensor manipulation, +broadcasting operations, and mathematical computations commonly used +in the e3nn library. +""" +import numpy as np +from scipy.linalg import expm + +from mindspore import Tensor, ops +from mindspore.ops import operations as P + + +def norm_keep(input_x, axis): + r""" + Compute the matrix norm or vector norm of a given tensor, and the output tensors have dimension retained. + + Args: + input_x (Tensor): Input tensor. The dtype must be float32 or float16. + axis (Union[int, list, tuple]): Specifies which dimension or dimensions of input to calculate the norm across. + + Returns: + Tensor, has the same dtype and shape as `input`. 
+ """ + return ops.expand_dims(input_x.norm(None, axis, False), axis=axis) + + +def _to_tensor(arg): + if isinstance(arg, (int, float)): + return Tensor(arg) + if isinstance(arg, (np.ndarray, list, tuple)): + return Tensor(arg) + if isinstance(arg, Tensor): + return arg + raise TypeError + + +def broadcast_shapes(*shapes): + r""" + Return the broadcast shape of the shapes of input tensors. + + Args: + shapes (tuple): Any number of shapes of tensors to be broadcasted. + + Returns: + Tuple, a shape compatible with all input shapes. + """ + max_len = 0 + for shape in shapes: + if isinstance(shape, int): + if max_len < 1: + max_len = 1 + elif isinstance(shape, (list, tuple)): + s = len(shape) + if max_len < s: + max_len = s + result = [1] * max_len + for shape in shapes: + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, (list, tuple)): + for i in range(-1, -1 - len(shape), -1): + if shape[i] < 0: + raise RuntimeError("Trying to create tensor with negative dimension ({}): ({})" + .format(shape[i], shape[i])) + if shape[i] == 1 or shape[i] == result[i]: + continue + if result[i] != 1: + raise RuntimeError( + "Shape mismatch: objects cannot be broadcast to a single shape") + result[i] = shape[i] + else: + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape) + return tuple(result) + + +def broadcast_tensors(*tensors): + r""" + Broadcasts the given tensors. + + Args: + tensors (Tensor): Any number of tensors of the same type. + + Returns: + A list of tensors, tensors after broadcast. + """ + shapes = [] + for tensor in tensors: + shapes.append(tensor.shape) + shape = broadcast_shapes(*shapes) + res = [] + for tensor in tensors: + if shape: + res.append(ops.broadcast_to(tensor, shape)) + else: + res.append(tensor) + return res + + +def broadcast_args(*args): + r""" + Broadcasts the given data with multiple types. + + Args: + *arg (Union[Tensor[float32], list[float], tuple[float], + ndarray[np.float32], float]): Any number of data to be broadcasted. + + Returns: + A list of tensors, tensors after broadcast. + """ + tensors = [] + for arg in args: + tensors.append(_to_tensor(arg)) + res = broadcast_tensors(*tensors) + return res + + +def _ndexpm(mat): + """Compute matrix-product exponential of matrices.""" + if isinstance(mat, Tensor): + mat = mat.asnumpy() + mat_shape = mat.shape + if len(mat_shape) < 2: + raise ValueError + if len(mat_shape) == 2: + return Tensor(expm(mat)) + mat = np.reshape(mat, (-1, mat_shape[-1], mat_shape[-1])) + n = mat.shape[0] + for i in range(n): + mat[i] = expm(mat[i]) + mat = np.reshape(mat, mat_shape) + return Tensor(mat) + + +def _expand_last_dims(x): + if isinstance(x, Tensor): + x = ops.expand_dims(x, -1) + x = ops.expand_dims(x, -1) + else: + x = x[..., None, None] + return x + + +def narrow(inputs, axis, start, length): + """tmp narrow API""" + begins = [0] * inputs.ndim + begins[axis] = start + sizes = [i for i in inputs.shape] + sizes[axis] = length + return P.Slice()(inputs, begins, sizes) diff --git a/mindscience/e3nn/utils/initializer.py b/mindscience/e3nn/utils/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..755538ea54e631c749c4f8f96885ae6fc1789546 --- /dev/null +++ b/mindscience/e3nn/utils/initializer.py @@ -0,0 +1,85 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Custom initializers for neural network parameters. + +This module provides custom weight initialization methods including +uniform distribution initializers and renormalization utilities +for various initialization schemes. +""" + +from mindspore.common.initializer import Initializer, _register, _init_random_uniform, _assignment, TruncatedNormal, \ + Normal, HeNormal, HeUniform + + +@_register() +class Uniform(Initializer): + r""" + Generates an array with values sampled from Uniform distribution :math:`{U}(-\text{scale}, \text{scale})` in order + to initialize a tensor. + + Args: + scale (float): The bound of the Uniform distribution. Default: 1.0. + + + Examples: + >>> import mindspore + >>> from mindspore.common.initializer import initializer, Uniform + >>> tensor1 = initializer(Uniform(), [1, 2, 3], mindspore.float32) + >>> tensor2 = initializer('uniform', [1, 2, 3], mindspore.float32) + """ + + def __init__(self, scale=1.): + super(Uniform, self).__init__(scale=scale) + self.scale = scale + + def _initialize(self, arr): + tmp = _init_random_uniform(0., self.scale, arr.shape) + _assignment(arr, tmp) + + +def renormal_initializer(init_method): + """ + Normalize and convert initialization method to proper initializer instance. + + Args: + init_method (str or Initializer): The initialization method name or + an Initializer instance. Supported string values are: + 'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', + 'he_uniform', 'he_normal', 'xavier_uniform'. + + Returns: + Initializer: The corresponding initializer instance. + + Raises: + ValueError: If the initialization method is not supported. + """ + name_list = ['zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', 'xavier_uniform'] + if not init_method in name_list and not isinstance(init_method, Initializer): + raise ValueError( + f'initial method \"{init_method}\" is not supported.') + + if init_method == 'truncatedNormal': + init_method = TruncatedNormal(sigma=1.) + elif init_method == 'normal': + init_method = Normal(sigma=1.) + elif init_method == 'uniform': + init_method = Uniform() + elif init_method == 'he_normal': + init_method = HeNormal() + elif init_method == 'he_uniform': + init_method = HeUniform() + + return init_method diff --git a/mindscience/e3nn/utils/linalg.py b/mindscience/e3nn/utils/linalg.py new file mode 100644 index 0000000000000000000000000000000000000000..62ee37f2ecd4d9ca632c8144bb75b5ed314212fe --- /dev/null +++ b/mindscience/e3nn/utils/linalg.py @@ -0,0 +1,40 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Linear algebra utilities for matrix operations. + +This module provides utility functions for linear algebra operations, +including matrix direct sum and other matrix manipulation functions +commonly used in the e3nn library. +""" +from mindspore import ops + + +def _direct_sum(*matrices): + r"""Direct sum of matrices, put them in the diagonal + """ + front_indices = matrices[0].shape[:-2] + m = sum(x.shape[-2] for x in matrices) + n = sum(x.shape[-1] for x in matrices) + total_shape = list(front_indices) + [m, n] + zeros = ops.Zeros() + out = zeros(tuple(total_shape), matrices[0].dtype) + i, j = 0, 0 + for x in matrices: + m, n = x.shape[-2:] + out[..., i: i + m, j: j + n] = x + i += m + j += n + return out diff --git a/mindscience/e3nn/utils/ncon.py b/mindscience/e3nn/utils/ncon.py new file mode 100644 index 0000000000000000000000000000000000000000..f47d64bab3ee63fbd8358026ed441e4345f1e5dc --- /dev/null +++ b/mindscience/e3nn/utils/ncon.py @@ -0,0 +1,699 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ncon""" +from copy import deepcopy +import numpy as np + +from mindspore import ops, nn, vmap +from mindspore.numpy import tensordot, trace, expand_dims + + +def list_to_tuple(lst): + """list_to_tuple""" + return tuple(list_to_tuple(item) if isinstance(item, list) else item for item in lst) + + +def nest_vmap(fn, in_list, out_list, pt): + """nest vmap function""" + if pt == len(in_list) - 1: + return vmap(fn, in_list[pt], out_list[pt]) + return vmap(nest_vmap(fn, in_list, out_list, pt + 1), in_list[pt], out_list[pt]) + + +def _create_order(con_list): + """ Identify all unique, positive indices and return them sorted. 
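+
+    A small sketch with hypothetical index lists:
+        _create_order([[-1, 2, 1], [2, -2, 1]]) -> [1, 2]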
""" + flat_con = np.concatenate(con_list) + return np.unique(flat_con[flat_con > 0]).tolist() + + +def _single_trace(con, leg): + """_single_trace""" + leg = np.where(np.array(con) == leg)[0] + con = np.delete(con, leg).tolist() + return con, leg.tolist() + + +def _find_sum(con_list): + """_find_sum + + Args: + con_list: con_list + + Returns: + legs + """ + flat = [] + for item in con_list: + flat += item + legs = [] + for leg in np.unique(flat): + if leg < 0: + continue + if np.sum(np.array(flat) == leg) == 1: + legs.append(leg) + return legs + + +def _find_trace(con_list): + """_find_trace + + Args: + con_list: con_list + + Returns: + legs_list + """ + legs_list = [] + for i in range(len(con_list)): + tr_num = len(con_list[i]) - len(np.unique(con_list[i])) + legs = [] + if tr_num: + for leg in np.unique(con_list[i]): + if sum(con_list[i] == leg) > 1 and leg > 0: + leg = np.where(con_list[i] == leg)[0].tolist() + legs.append(leg) + con_list[i] = np.delete(con_list[i], leg).tolist() + + legs_list.append(legs) + return legs_list + + +def _find_batch(con_list): + """_find_batch + + Args: + con_list: con_list + + Returns: + outer + """ + outer = [] + for i in con_list: + if not isinstance(i, np.ndarray): + i = np.array(i) + outer.extend(i[i < 0].tolist()) + if not outer: + return None + if -len(outer) == min(outer): + return None + + for leg in np.unique(outer): + if sum(outer == leg) == 1: + outer = np.delete(outer, outer.index(leg)).tolist() + + return outer + + +def _process_perm(con, batch_leg): + """_process_perm""" + p = list(range(len(con))) + for i, ind in enumerate(batch_leg): + j = con.index(ind) + if i == j: + continue + con[i], con[j] = con[j], con[i] + p[i], p[j] = p[j], p[i] + + return con, tuple(p) + + +def _make_dict(mode, + inds=None, + legs=None, + batch_leg=None, + p_list=None, + res_legs=None, + permute_index=None, + expand_axis=None): + """_summary_ + + Args: + mode: mode + inds: inds. Defaults to None. + legs: legs. Defaults to None. + batch_leg: batch_leg. Defaults to None. + p_list: p_list. Defaults to None. + res_legs: res_legs. Defaults to None. + permute_index: permute_index. Defaults to None. + expand_axis: expand_axis. Defaults to None. 
+ + Raises: + ValueError: ValueError + + Returns: + d + """ + d = {} + calculate_mode = 'mode' + indices = 'inds' + indices_legs = 'legs' + d[calculate_mode] = mode + + if d[calculate_mode] == 'permute': + d['perms'] = p_list + + elif d[calculate_mode] == 'outer': + d[indices] = inds + + elif d[calculate_mode] in ('diag', 'sum', 'trace'): + d[indices] = inds + d[indices_legs] = legs + + elif d[calculate_mode] == 'ndot': + d[indices] = inds + d[indices_legs] = legs + d['batch_leg'] = batch_leg + + elif d[calculate_mode] == 'hadamard': + d[indices] = inds + d[indices_legs] = legs + d['res_legs'] = res_legs + d['permute_index'] = permute_index + d['expand_axis'] = expand_axis + + else: + raise ValueError + + return d + + +def _process_commands(con_list): + """_process_commands + + Args: + con_list: con_list + + Returns: + conmmands, operators + """ + conmmands = [] + operators = [] + + # find sum index + sum_legs = _find_sum(con_list) + for leg in sum_legs: + for i, con in enumerate(con_list): + if leg in con: + leg_ind = con.index(leg) + con_list[i].remove(leg) + conmmands.append(_make_dict('sum', [i], [leg_ind])) + operators.append(ops.sum) + + # find trace + trace_legs = _find_trace(con_list) + for i, leg_list in enumerate(trace_legs): + if leg_list: + for legs in leg_list: + conmmands.append(_make_dict('trace', [i], legs)) + operators.append(trace) + + order = _create_order(con_list) + batch_legs = _find_batch(con_list) + + if not con_list[0]: + return conmmands, operators + + do_ndot(con_list, conmmands, operators, order, batch_legs) + + # do Hadamard(alike) product + do_hadamard(con_list, conmmands, operators) + + # do outer product + for i, con in enumerate(con_list): + if not i: + continue + if con: + inds = [0, i] + for leg in con: + con_list[0].append(leg) + con_list[i] = [] + conmmands.append(_make_dict('outer', inds)) + operators.append(tensordot) + + # do diagonal + min_leg = min(con_list[0]) + for leg in range(-1, min_leg - 1, -1): + num_leg = con_list[0].count(leg) + while num_leg > 1: + i = con_list[0].index(leg) + j = con_list[0].index(leg, i + 1) + conmmands.append(_make_dict('diag', [0], [i, j])) + operators.append(ops.diagonal) + con_list[0] = con_list[0][:i] + con_list[0][i + 1:j] + con_list[0][j + 1:] + [leg] + num_leg = con_list[0].count(leg) + + # do final permutation + fin_con = list(range(-1, -1 - len(con_list[0]), -1)) + con_list[0], p = _process_perm(con_list[0], fin_con) + conmmands.append(_make_dict('permute', p_list=[p])) + operators.append(ops.permute) + + return conmmands, operators + + +def do_ndot(con_list, conmmands, operators, order, batch_legs): + """do_ndot + + Args: + con_list: con_list + conmmands: conmmands + operators: operators + order: order + batch_legs: batch_legs + """ + while order: + leg_now = order[-1] + inds = [] + legs = [] + batch_legs_now = [] + + # find the two tensors' indices + for i, item in enumerate(con_list): + if leg_now in item: + inds.append(i) + + # check trace + if len(inds) == 1: + con_list[inds[0]], legs = _single_trace(con_list[inds[0]], leg_now) + conmmands.append(_make_dict('trace', inds, legs)) + operators.append(trace) + + else: + # find batch legs + batch_leg_inds = [] + if batch_legs is not None: + tmp = np.intersect1d(con_list[inds[0]], con_list[inds[1]]) + batch_legs_now = np.intersect1d(tmp, batch_legs, False).tolist() + + # find indices of batch legs + for batch_leg in batch_legs_now: + i_leg_0 = con_list[inds[0]].index(batch_leg) + i_leg_1 = con_list[inds[1]].index(batch_leg) + con_list[inds[0]].remove(batch_leg) 
+                    con_list[inds[1]].remove(batch_leg)
+                    batch_leg_inds.append((i_leg_0, i_leg_1, None))
+
+            ndot_legs = []
+            ndot_leg_inds = []
+            # find all ndot legs and their indices
+            for leg in con_list[inds[0]]:
+                if leg in con_list[inds[1]]:
+                    i_leg_0 = con_list[inds[0]].index(leg)
+                    i_leg_1 = con_list[inds[1]].index(leg)
+                    ndot_legs.append(leg)
+                    ndot_leg_inds.append([i_leg_0, i_leg_1])
+
+            # do ndot contraction and update order
+            for leg in ndot_legs:
+                con_list[inds[0]].remove(leg)
+                con_list[inds[1]].remove(leg)
+            for leg in ndot_legs:
+                if leg != leg_now:
+                    order.remove(leg)
+
+            ndot_leg_inds = ndot_leg_inds[0] if len(ndot_leg_inds) == 1 else np.array(
+                ndot_leg_inds).transpose().tolist()
+            conmmands.append(_make_dict('ndot', inds, list_to_tuple(ndot_leg_inds), batch_leg_inds))
+            operators.append(
+                nest_vmap(tensordot, batch_leg_inds, [0] * len(batch_leg_inds), 0) if batch_leg_inds else tensordot)
+
+            # merge two con_list
+            for leg in con_list[inds[1]]:
+                if leg not in batch_legs_now:
+                    con_list[inds[0]].append(leg)
+            con_list[inds[1]] = []
+            con_list[inds[0]] = batch_legs_now + con_list[inds[0]]
+
+        order = order[:-1]
+
+
+def do_hadamard(con_list, conmmands, operators):
+    """do_hadamard
+
+    Args:
+        con_list: con_list
+        conmmands: conmmands
+        operators: operators
+    """
+    is_con_list_not_none = len(con_list) == 2 and con_list[1]
+    if is_con_list_not_none and not [i for i in con_list[0] if i > 0] and not [i for i in con_list[1] if i > 0]:
+        con_list_all = []
+        for con in con_list:
+            con_list_all.extend(con)
+        con_min_leg = min(con_list_all)
+        out_list = [i for i in range(-1, con_min_leg - 1, -1)]
+
+        res_legs = []
+        for ind in out_list:
+            for i, con in enumerate(con_list):
+                if ind in con:
+                    res_legs.append((i, con.index(ind)))
+                    break
+
+        hadamard_legs = [[], []]
+        con_raw = deepcopy(con_list)
+        handle_inds(con_list, out_list, hadamard_legs)
+
+        expand_axis = deepcopy(hadamard_legs)
+        for i, axis in enumerate(expand_axis):
+            if axis and len(axis) <= 1:
+                expand_axis[i] = axis[0]
+
+        # input permute
+        permute_index = [[], []]
+        con_sort = deepcopy(con_raw)
+        for i, con in enumerate(con_raw):
+            con_sort[i].sort(reverse=True)
+            _, permute_index[i] = _process_perm(con, con_sort[i])
+
+        conmmands.append(
+            _make_dict('hadamard',
+                       inds=[0, 1],
+                       legs=hadamard_legs,
+                       res_legs=res_legs,
+                       permute_index=permute_index,
+                       expand_axis=expand_axis))
+        operators.append([ops.permute, ops.tile, ops.mul, expand_dims])
+
+
+def handle_inds(con_list, out_list, hadamard_legs):
+    """handle_inds"""
+    for i, con in enumerate(con_list):
+        if con:
+            for ind in out_list:
+                if ind not in con:
+                    hadamard_legs[i].append((out_list.index(ind)))
+            if i:
+                con_list[i] = []
+            else:
+                con_list[i] = out_list
+
+
+class Ncon(nn.Cell):
+    r"""
+    Multiple-tensor contraction operator with functionality similar to einsum.
+
+    Args:
+        con_list (List[List[int]]): lists of indices for each tensor.
+            The length of each list in `con_list` should coincide with the number of
+            dimensions of the corresponding tensor.
+            The positive indices indicate the dimensions to be contracted or summed.
+            The negative indices indicate the dimensions to be kept (as batch dimensions).
+
+    Inputs:
+        - **input** (List[Tensor]) - Tensor List.
+
+    Outputs:
+        - **output** (Tensor) - The shape of tensor depends on the input and the computation process.
+
+    Raises:
+        ValueError: If the number of commands does not match the number of operations.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindscience.e3nn.utils import Ncon
+        Trace of a matrix:
+        >>> a = ops.ones((3, 3))
+        >>> Ncon([[1, 1]])([a])
+        3.0
+        Diagonal of a matrix:
+        >>> Ncon([[-1, -1]])([a])
+        [1. 1. 1.]
+        Outer product:
+        >>> b = ops.ones((2))
+        >>> c = ops.ones((3))
+        >>> Ncon([[-1], [-2]])([b, c]).shape
+        (2, 3)
+        Batch matrix multiplication:
+        >>> d = ops.ones((2, 3, 4))
+        >>> e = ops.ones((2, 4, 1))
+        >>> Ncon([[-1, -2, 1], [-1, 1, -3]])([d, e]).shape
+        (2, 3, 1)
+    """
+
+    def __init__(self, con_list):
+        super().__init__()
+        self.con_list = tuple(con_list)
+        con_list_copy = deepcopy(con_list)
+        self.commands, self.ops = _process_commands(con_list_copy)
+        if len(self.commands) != len(self.ops):
+            raise ValueError(
+                f'The number of commands ({len(self.commands)}) does not match the number of operations ({len(self.ops)}).')
+
+    def construct(self, ten_list):
+        """
+        Contract the list of input tensors.
+        """
+        i = 0
+        for d in self.commands:
+            if d['mode'] == 'diag':
+                ten_list[0] = self.ops[i](ten_list[0], 0, *d['legs'])
+            elif d['mode'] == 'permute':
+                ten_list[0] = self.ops[i](ten_list[0], d['perms'][0])
+            elif d['mode'] == 'sum':
+                i1 = d['inds'][0]
+                ten_list[i1] = self.ops[i](ten_list[i1], d['legs'][0])
+            elif d['mode'] == 'trace':
+                i1 = d['inds'][0]
+                ten_list[i1] = self.ops[i](ten_list[i1], 0, d['legs'][0], d['legs'][1])
+            elif d['mode'] == 'outer':
+                i1, i2 = d['inds']
+                ten_list[i1] = self.ops[i](ten_list[i1], ten_list[i2], 0)
+            elif d['mode'] == 'ndot':
+                i1, i2 = d['inds']
+                ten_list[i1] = self.ops[i](ten_list[i1], ten_list[i2], d['legs'])
+            elif d['mode'] == 'hadamard':
+                i1, i2 = d['inds']
+                a = ten_list[i1]
+                b = ten_list[i2]
+                res_legs = d['res_legs']
+
+                a = ops.permute(a, d['permute_index'][i1])
+                b = ops.permute(b, d['permute_index'][i2])
+
+                if d['expand_axis'][i1]:
+                    a = expand_dims(a, d['expand_axis'][i1])
+                if d['expand_axis'][i2]:
+                    b = expand_dims(b, d['expand_axis'][i2])
+
+                tile_index = [[1 for _ in res_legs], [1 for _ in res_legs]]
+                for j in range(len(d['legs'][i1])):
+                    tile_index[0][d['legs'][i1][j]] = ten_list[res_legs[d['legs'][i1][j]][0]].shape[res_legs[
+                        d['legs'][i1][j]][1]]
+                for j in range(len(d['legs'][i2])):
+                    tile_index[1][d['legs'][i2][j]] = ten_list[res_legs[d['legs'][i2][j]][0]].shape[res_legs[
+                        d['legs'][i2][j]][1]]
+                a = ops.tile(a, tuple(tile_index[0]))
+                b = ops.tile(b, tuple(tile_index[1]))
+
+                ten_list[i1] = ops.mul(a, b)
+            else:
+                i += 1
+                continue
+            i += 1
+        return ten_list[0]
+
+    def __repr__(self):
+        s = f'Ncon: {self.con_list}\n'
+        for d in self.commands:
+            s += str(d) + '\n'
+        return s
+
+
+def test_other():
+    """test_other"""
+    ncon = Ncon([[5, -1, 1, 4, 3, -2], [3, -2, -1, 4, 2], [2, -3], [-3, -4]])
+    v1 = ops.ones((3, 1, 3, 4, 5, 2))
+    v2 = ops.ones((5, 2, 1, 4, 6))
+    v3 = ops.ones((6, 3))
+    v4 = ops.ones((3, 4))
+    print(ncon)
+    out = ncon([v1, v2, v3, v4])
+    print(out.shape)
+
+    ncon = Ncon([[-1, 2], [-1, 1], [2, 1, -2]])
+    v1 = ops.ones((20, 50))
+    v2 = ops.ones((20, 2))
+    v3 = ops.ones((50, 2, 7))
+    print(ncon)
+    out = ncon([v1, v2, v3])
+    print(out.shape)
+
+    ncon = Ncon([[-1, -2, 1], [-1, 1]])
+    v1 = ops.ones((3, 4, 5))
+    v2 = ops.ones((3, 5))
+    print(ncon)
+    out = ncon([v1, v2])
+    print(out.shape)
+
+
+def test_diagonal():
+    """test_diagonal"""
+    ncon = Ncon([[-1, -1]])
+    v1 = ops.ones((3, 3))
+    print(ncon)
+    out = ncon([v1])
+    print(out.shape)
+    print(out)
+
+
+def test_outer():
+    """test_outer"""
+    ncon = Ncon([[-1], [-2]])
+    v1 = ops.ones((2))
+    v2 = ops.ones((3))
+    print(ncon)
+    out = ncon([v1, v2])
print(out.shape) + print(out) + + +def test_outer_multi_input(): + """test_other""" + ncon = Ncon([[-1], [-2], [-3]]) + v1 = ops.ones((2)) + v2 = ops.ones((3)) + v3 = ops.ones((4)) + print(ncon) + out = ncon([v1, v2, v3]) + print(out.shape) + print(out) + + +def test_ndot(): + """test_other""" + ncon = Ncon([[-1, -2, 1], [-1, 1]]) + v1 = ops.ones((3, 4, 5)) + v2 = ops.ones((3, 5)) + print(ncon) + out = ncon([v1, v2]) + print(out.shape) + print(out) + + +def test_ndot_2(): + """test_other""" + ncon = Ncon([[-1, -2, 1, 2], [-1, 1, 2]]) + v1 = ops.ones((3, 4, 5, 6)) + v2 = ops.ones((3, 5, 6)) + print(ncon) + out = ncon([v1, v2]) + print(out.shape) + print(out) + + +def test_hadamard(): + """test_hadamard""" + a = np.arange(6).reshape((2, 3)) + b = np.arange(6).reshape((2, 3)) + print(a) + print(b) + einstr = f"zu,zu->zu" + d = np.einsum(einstr, a, b) + print(d) + print(d.shape) + + ma = ms.Tensor(a, dtype=ms.float32) + mb = ms.Tensor(b, dtype=ms.float32) + ncon = Ncon([[-1, -2], [-1, -2]]) + print(ncon) + md = ncon([ma, mb]) + print(md.shape) + print(np.allclose(md.asnumpy(), d)) + + +def test_hadamard_alike(): + """test_hadamard_alike""" + a = np.arange(8).reshape((2, 4)) + b = np.arange(24).reshape((2, 3, 4)) + print(a) + print(b) + einstr = f"zi,zui->zui" + d = np.einsum(einstr, a, b) + print(d) + print(d.shape) + + ma = ms.Tensor(a, dtype=ms.float32) + mb = ms.Tensor(b, dtype=ms.float32) + ncon = Ncon([[-1, -3], [-1, -2, -3]]) + print(ncon) + md = ncon([ma, mb]) + print(md.shape) + print(np.allclose(md.asnumpy(), d)) + + +def test_hadamard_with_outer(): + """test_hadamard_with_outer""" + a = np.arange(24).reshape((2, 3, 4)) + b = np.arange(30).reshape((2, 3, 5)) + print(f"a:\n {a}") + print(f"b:\n {b}") + + einstr = f"zui,zuj->zuij" + + d = np.einsum(einstr, a, b) + print(f"d:\n {d}") + print(f"d.shape:\n {d.shape}") + + ma = ms.Tensor(a, dtype=ms.float32) + mb = ms.Tensor(b, dtype=ms.float32) + + ncon = Ncon([[-1, -2, -3], [-1, -2, -4]]) + print(ncon) + md = ncon([ma, mb]) + print(md.shape) + print(np.allclose(md.asnumpy(), d)) + + +def test_hadamard_outer_nosequential(): + """test_hadamard_outer_nosequential""" + a = np.arange(8).reshape((2, 4)) + b = np.arange(30).reshape((2, 5, 3)) + print(f"a:\n {a}") + print(f"b:\n {b}") + + einstr = f"ac,adb->abcd" + + d = np.einsum(einstr, a, b) + print(f"d:\n {d}") + print(f"d.shape:\n {d.shape}") + ma = ms.Tensor(a, dtype=ms.float32) + mb = ms.Tensor(b, dtype=ms.float32) + + ncon = Ncon([[-1, -3], [-1, -4, -2]]) + print(ncon) + md = ncon([ma, mb]) + print(md.shape) + print(np.allclose(md.asnumpy(), d)) + + +def test_sum(): + """test_other""" + ncon = Ncon([[1, 2]]) + v1 = ops.ones((2, 3)) + print(ncon) + out = ncon([v1]) + print(out.shape) + print(out) + + +if __name__ == '__main__': + import mindspore as ms + + ms.set_context(device_target="GPU", device_id=4, mode=ms.GRAPH_MODE, save_graphs=False) + np.random.seed(123) + + test_hadamard_outer_nosequential() diff --git a/mindscience/e3nn/utils/perm.py b/mindscience/e3nn/utils/perm.py new file mode 100644 index 0000000000000000000000000000000000000000..ca94ae3178491aef7a34c14a31208ae66c6faa4c --- /dev/null +++ b/mindscience/e3nn/utils/perm.py @@ -0,0 +1,178 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""permutation operators""" +import random +import math + + +def _is_perm(p): + return sorted(set(p)) == list(range(len(p))) + + +def _identity(n): + return tuple(i for i in range(n)) + + +def _compose(p1, p2): + r""" + compute p1 . p2 + p: i |-> p[i] + [p1.p2](i) = p1(p2(i)) = p1[p2[i]] + """ + assert _is_perm(p1) and _is_perm(p2) + assert len(p1) == len(p2) + + return tuple(p1[p2[i]] for i in range(len(p1))) + + +def _inverse(p): + r""" + compute the inverse permutation + """ + return tuple(p.index(i) for i in range(len(p))) + + +def _rand(n): + i = random.randint(0, math.factorial(n) - 1) + return _from_int(i, n) + + +def _from_int(i, n): + pool = list(range(n)) + p = [] + for _ in range(n): + j = i % n + i = i // n + p.append(pool.pop(j)) + n -= 1 + return tuple(p) + + +def _to_int(p): + """ + Convert a permutation to its integer representation. + + Args: + p (tuple): A permutation represented as a tuple. + + Returns: + int: The integer representation of the permutation. + """ + n = len(p) + pool = list(range(n)) + i = 0 + m = 1 + for j in p: + k = pool.index(j) + i += k * m + m *= len(pool) + pool.pop(k) + return i + + +def _group(n): + return {_from_int(i, n) for i in range(math.factorial(n))} + + +def _germinate(subset): + """ + Generate the group closure of a subset of permutations. + + Args: + subset (set): A set of permutations. + + Returns: + set: The group closure containing all permutations that can be + generated from the input subset through composition and inversion. + """ + while True: + n = len(subset) + subset = subset.union([_inverse(p) for p in subset]) + subset = subset.union([ + _compose(p1, p2) + for p1 in subset + for p2 in subset + ]) + if len(subset) == n: + return subset + + +def _is__(g): + """ + Check if a set of permutations forms a group. + + Args: + g (set): A set of permutations to check. + + Returns: + bool: True if the set forms a group, False otherwise. + """ + if not g: + return False + + n = len(next(iter(g))) + + for p in g: + assert len(p) == n, p + + if not _identity(n) in g: + return False + + for p in g: + if not _inverse(p) in g: + return False + + for p1 in g: + for p2 in g: + if not _compose(p1, p2) in g: + return False + + return True + + +def _to_cycles(p): + """ + Convert a permutation to its cycle representation. + + Args: + p (tuple): A permutation represented as a tuple. + + Returns: + set: A set of tuples representing the cycles of the permutation. + Only cycles of length >= 2 are included. 
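+
+    A short sketch with a hypothetical permutation:
+        _to_cycles((1, 0, 2)) -> {(0, 1)}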
+ """ + n = len(p) + + cycles = set() + + for i in range(n): + c = [i] + while p[i] != c[0]: + i = p[i] + c.append(i) + if len(c) >= 2: + i = c.index(min(c)) + c = c[i:] + c[:i] + cycles.add(tuple(c)) + + return cycles + + +def _sign(p): + s = 1 + for c in _to_cycles(p): + if len(c) % 2 == 0: + s = -s + return s diff --git a/mindscience/e3nn/utils/radius.py b/mindscience/e3nn/utils/radius.py new file mode 100644 index 0000000000000000000000000000000000000000..b6cf2cd5ff6282da7d46ff55e2f419500207c1f3 --- /dev/null +++ b/mindscience/e3nn/utils/radius.py @@ -0,0 +1,248 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""radius""" +from scipy.spatial import cKDTree +import numpy as np + + +def _reshape_and_batch(x, batch_x): + """_reshape_and_batch""" + if x.ndim > 2: + if batch_x is None: + batch_x = np.broadcast_to(np.arange(0, x.shape[0]).reshape(-1, 1), (x.shape[0], x.shape[1])).flatten() + x = x.reshape(-1, x.shape[-1]) + else: + if batch_x is None: + batch_x = np.zeros(x.shape[0], dtype=x.dtype) + x = x.reshape((-1, 1)) if x.ndim == 1 else x + + return x, batch_x.astype(np.int64) + + +def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32): + r""" + Find all points in `x` for each element in `y` within distance `r`. + + Args: + x (ndarray): node feature matrix of x. + y (ndarray): node feature matrix of y. + r (ndarray, float): the radius. + batch_x (ndarray): batch vector of x. If it is none, then calculate based on x and return. Default: ``None``. + batch_y (ndarray): batch vector of y. If it is none, then calculate based on y and return. Default: ``None``. + max_num_neighbors (int): The maximum number of neighbors to return for each element in `y`. Dufault: ``32``. + + Returns: + edge_index (numpy.ndarray) - including edges of source and destination. + + batch_x (numpy.ndarray) - batch vector of x. + + batch_y (numpy.ndarray) - batch vector of y. + + Raises: + ValueError: If the last dimension of `x` and `y` do not match. 
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindscience.e3nn.utils import radius
+        >>> import numpy as np
+        >>> np.random.seed(1)
+        >>> x = np.random.random((5, 12, 3))
+        >>> r = 0.5
+        >>> edge_index, batch_x, batch_y = radius(x, x, r)
+        >>> print(edge_index.shape)
+        (2, 222)
+        >>> print(batch_x.shape)
+        (60,)
+        >>> print(batch_y.shape)
+        (60,)
+
+    """
+    if not x.shape[-1] == y.shape[-1]:
+        raise ValueError("Feature sizes of `x` and `y` do not match.")
+    if max_num_neighbors < 1:
+        raise ValueError(f'`max_num_neighbors` must be at least 1, got {max_num_neighbors}.')
+
+    x, batch_x = _reshape_and_batch(x, batch_x)
+    y, batch_y = _reshape_and_batch(y, batch_y)
+
+    x = np.concatenate((x, 2 * r * batch_x.reshape(-1, 1).astype(x.dtype)), axis=-1)
+    y = np.concatenate((y, 2 * r * batch_y.reshape(-1, 1).astype(y.dtype)), axis=-1)
+
+    tree = cKDTree(x)
+    _, col = tree.query(y, k=max_num_neighbors, distance_upper_bound=r + 1e-8)
+    row = [np.full_like(c, i) for i, c in enumerate(col)]
+    col = col.flatten()
+    row = np.concatenate(row, axis=0)
+    mask = col < int(tree.n)
+
+    return np.stack([row[mask], col[mask]], axis=0), batch_x, batch_y
+
+
+# pylint: disable=C0103
+# pylint: disable=W0612
+def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32, flow='source_to_target'):
+    r"""
+    Computes graph edges to all points within a given distance.
+
+    Args:
+        x (ndarray): node feature matrix.
+        r (ndarray, float): the radius.
+        batch (Tensor): batch vector. If ``None``, it is computed and returned. Default: ``None``.
+        loop (bool): whether to include self-loops in the graph. Default: ``False``.
+        max_num_neighbors (int): The maximum number of neighbors to return for each node. Default: ``32``.
+        flow (str): {'source_to_target', 'target_to_source'}, the flow direction when using in combination with
+            message passing. Default: ``'source_to_target'``.
+
+    Returns:
+        edge_index (ndarray) - including edges of source and destination.
+
+        batch (ndarray) - batch vector.
+
+    Raises:
+        ValueError: If `flow` is not in {'source_to_target', 'target_to_source'}.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindscience.e3nn.utils import radius_graph
+        >>> import numpy as np
+        >>> np.random.seed(1)
+        >>> x = np.random.random((5, 12, 3))
+        >>> r = 0.5
+        >>> edge_index, batch = radius_graph(x, r)
+        >>> print(edge_index.shape)
+        (2, 162)
+        >>> print(batch.shape)
+        (60,)
+    """
+
+    if flow not in ['source_to_target', 'target_to_source']:
+        raise ValueError('`flow` should be in ["source_to_target", "target_to_source"].')
+    (row, col), batch, _ = radius(x, x, r, batch, batch, max_num_neighbors + 1)
+    row, col = (col, row) if flow == 'source_to_target' else (row, col)
+    if not loop:
+        mask = row != col
+        row, col = row[mask], col[mask]
+    return np.stack([row, col], axis=0), batch
+
+
+def radius_full(x, y, batch_x=None, batch_y=None):
+    r"""
+    Find all points in `x` for each element in `y`.
+
+    Args:
+        x (Tensor): node feature matrix.
+        y (Tensor): node feature matrix.
+        batch_x (ndarray): batch vector of x. If ``None``, it is computed from `x` and returned. Default: ``None``.
+        batch_y (ndarray): batch vector of y. If ``None``, it is computed from `y` and returned. Default: ``None``.
+
+    Returns:
+        edge_index (numpy.ndarray) - including edges of source and destination.
+
+        batch_x (numpy.ndarray) - batch vector of x.
+
+        batch_y (numpy.ndarray) - batch vector of y.
+
+    Raises:
+        ValueError: If the last dimensions of `x` and `y` do not match.
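+
+    Note:
+        This is the dense counterpart of `radius`: every pair of points that
+        shares a batch index is connected, regardless of distance.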
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindscience.e3nn.utils import radius_full
+        >>> from mindspore import ops, Tensor
+        >>> x = Tensor(ops.ones((5, 12, 3)))
+        >>> edge_index, batch_x, batch_y = radius_full(x, x)
+        >>> print(edge_index.shape)
+        (2, 720)
+        >>> print(batch_x.shape)
+        (60,)
+        >>> print(batch_y.shape)
+        (60,)
+
+    """
+    if not x.shape[-1] == y.shape[-1]:
+        raise ValueError("Feature sizes of `x` and `y` do not match.")
+
+    if x.ndim > 2 and y.ndim > 2:
+        b_x, b_y = x.shape[0], y.shape[0]
+        len_x, len_y = x.shape[1], y.shape[1]
+    else:
+        b_x, b_y = 1, 1
+        len_x, len_y = x.shape[0], y.shape[0]
+
+    x, batch_x = _reshape_and_batch(x, batch_x)
+    y, batch_y = _reshape_and_batch(y, batch_y)
+
+    batch_unique = np.unique(batch_x)
+    _row = []
+    edge_dst = []
+    for i in batch_unique:
+        _row.extend(np.arange(len_y) + i * len_y)
+        _col = np.arange(len_x) + i * len_x
+        edge_dst.extend(np.broadcast_to(_col, (len_y, len_x)).flatten())
+    edge_src = np.broadcast_to(np.array(_row).reshape(-1, 1), (len(_row), len_x)).flatten()
+    edge_dst = np.array(edge_dst)
+
+    return np.stack([edge_src, edge_dst]), batch_x, batch_y
+
+
+def radius_graph_full(x, batch=None, loop=False, flow='source_to_target'):
+    r"""
+    Computes graph edges between all pairs of points that share a batch index.
+
+    Args:
+        x (Tensor): node feature matrix.
+        batch (Tensor): batch vector. If ``None``, it is computed and returned. Default: ``None``.
+        loop (bool): whether to include self-loops in the graph. Default: ``False``.
+        flow (str): {'source_to_target', 'target_to_source'}, the flow direction when using in combination with
+            message passing. Default: ``'source_to_target'``.
+
+    Returns:
+        edge_index (ndarray) - including edges of source and destination.
+
+        batch (ndarray) - batch vector.
+
+    Raises:
+        ValueError: If `flow` is not in {'source_to_target', 'target_to_source'}.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindscience.e3nn.utils import radius_graph_full
+        >>> from mindspore import ops, Tensor
+        >>> x = Tensor(ops.ones((5, 12, 3)))
+        >>> edge_index, batch = radius_graph_full(x)
+        >>> print(edge_index.shape)
+        (2, 660)
+        >>> print(batch.shape)
+        (60,)
+
+    """
+    if flow not in ['source_to_target', 'target_to_source']:
+        raise ValueError('`flow` should be in ["source_to_target", "target_to_source"].')
+
+    (row, col), batch, _ = radius_full(x, x, batch, batch)
+    row, col = (col, row) if flow == 'source_to_target' else (row, col)
+    if not loop:
+        mask = row != col
+        row, col = row[mask], col[mask]
+
+    return np.stack([row, col], axis=0), batch
diff --git a/tests/e3nn/nn/test_activation.py b/tests/e3nn/nn/test_activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b585d1097aa7717f91b82c996ce2fea39ba2e9e
--- /dev/null
+++ b/tests/e3nn/nn/test_activation.py
@@ -0,0 +1,109 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# ============================================================================ +"""Test cases for e3nn.nn.activation module - Core functionality only""" + +import pytest +import numpy as np +from mindspore import Tensor, ops, float32 +from mindscience.e3nn.nn.activation import Activation, _Normalize, _moment, _parity_function +from mindscience.e3nn.o3 import Irreps + + +class TestActivation: + """Core tests for Activation class""" + + def test_activation_basic_creation(self): + """Test basic Activation creation and forward pass""" + act = Activation('2x0e+1x0o', [ops.tanh, ops.abs]) + + x = Tensor(np.random.randn(3, 3), dtype=float32) + output = act(x) + + assert output.shape == (3, 3) + assert act.irreps_in == Irreps('2x0e+1x0o') + assert not np.any(np.isnan(output.asnumpy())) + + def test_activation_parity_change(self): + """Test activation function changes parity correctly""" + # abs function should change odd to even + act = Activation('2x0o', [ops.abs]) + + x = Tensor(np.random.randn(2, 2), dtype=float32) + output = act(x) + + assert act.irreps_out == Irreps('2x0e') # odd -> even + assert output.shape == (2, 2) + + def test_activation_invalid_non_scalar(self): + """Test activation with non-scalar irrep raises error""" + with pytest.raises(ValueError, match="non-scalar input"): + Activation('1x1e', [ops.tanh]) + + +class TestNormalize: + """Core tests for _Normalize class""" + + def test_normalize_basic(self): + """Test _Normalize normalizes activation function""" + norm_tanh = _Normalize(ops.tanh) + + x = Tensor(np.random.randn(100), dtype=float32) + output = norm_tanh(x) + + assert output.shape == x.shape + assert hasattr(norm_tanh, 'factor') + + def test_normalize_scaling_function(self): + """Test _Normalize correctly handles scaling functions""" + def scale_func(x): + return x * 2.0 # This will have second moment = 4.0 + + norm_func = _Normalize(scale_func) + + # Verify factor is approximately correct (should be around 1/sqrt(4) = 0.5) + expected_factor = 1.0 / np.sqrt(4.0) + assert abs(float(norm_func.factor) - expected_factor) < 5e-3 + + # Test normalization effect + x = Tensor(np.ones(5), dtype=float32) + output = norm_func(x) + expected_output = scale_func(x) * norm_func.factor + assert np.allclose(output.asnumpy(), expected_output.asnumpy(), atol=1e-4) + + +class TestUtilityFunctions: + """Core tests for utility functions""" + + def test_moment_calculation(self): + """Test _moment function calculates moments correctly""" + moment = _moment(ops.tanh, 2) + + assert isinstance(moment, Tensor) + assert moment.shape == () # scalar + assert moment.asnumpy() > 0 + + def test_parity_function_detection(self): + """Test _parity_function detects function parity""" + # Test even function + parity_even = _parity_function(lambda x: x**2) + assert parity_even == 1 # even function + + # Test odd function + parity_odd = _parity_function(lambda x: x) + assert parity_odd == -1 # odd function + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/e3nn/nn/test_batchnorm.py b/tests/e3nn/nn/test_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..3636bd718d8d4e56c36f28d30ba7b3a1ba195666 --- /dev/null +++ b/tests/e3nn/nn/test_batchnorm.py @@ -0,0 +1,142 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test cases for e3nn.nn.batchnorm module - Core functionality""" + +import pytest +import numpy as np +from mindspore import Tensor, float32 +from mindscience.e3nn.nn.batchnorm import BatchNorm +from mindscience.e3nn.o3 import Irreps + + +class TestBatchNorm: + """Core tests for BatchNorm class""" + + def test_batchnorm_basic_creation(self): + """Test basic BatchNorm creation and forward pass""" + bn = BatchNorm('2x0e+1x0o') + + x = Tensor(np.random.randn(4, 3), dtype=float32) + output = bn(x) + + assert output.shape == (4, 3) + assert bn.irreps == Irreps('2x0e+1x0o') + assert not np.any(np.isnan(output.asnumpy())) + + def test_batchnorm_normalization_correctness(self): + """Test that BatchNorm actually normalizes the data correctly""" + bn = BatchNorm('2x0e', eps=1e-8, affine=False) + + # Create data with known statistics + x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) + output = bn(x) + + # Manually compute expected normalized output + x_np = x.asnumpy() + x_mean = np.mean(x_np, axis=0) # [4.0, 5.0] + x_var = np.var(x_np, axis=0, ddof=0) # [5.0, 5.0] + expected_output = (x_np - x_mean) / np.sqrt(x_var + 1e-8) + + # Check that actual output matches manual calculation + output_np = output.asnumpy() + assert np.allclose(output_np, expected_output, atol=1e-6), \ + f"Normalization calculation incorrect" + + # Verify normalized output has zero mean and unit variance + assert abs(np.mean(output_np)) < 1e-6, "Mean should be close to 0" + assert abs(np.var(output_np, ddof=0) - 1.0) < 1e-5, "Variance should be close to 1" + + def test_batchnorm_affine_parameters(self): + """Test affine parameters (weight and bias) effect""" + x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) + + # Test with affine=True + bn_affine = BatchNorm('2x0e', affine=True, eps=1e-8) + weight = Tensor([2.0, 0.5], dtype=float32) + bias = Tensor([1.0, -1.0], dtype=float32) + bn_affine.weight.set_data(weight) + bn_affine.bias.set_data(bias) + + output_affine = bn_affine(x) + + # Verify computation: output = weight * normalized_input + bias + x_np = x.asnumpy() + x_mean = np.mean(x_np, axis=0) + x_var = np.var(x_np, axis=0, ddof=0) + x_normalized = (x_np - x_mean) / np.sqrt(x_var + 1e-8) + expected_output = x_normalized * weight.asnumpy() + bias.asnumpy() + + assert np.allclose(output_affine.asnumpy(), expected_output, atol=1e-5), \ + "Affine transformation calculation incorrect" + + # Test with affine=False + bn_no_affine = BatchNorm('2x0e', affine=False, eps=1e-8) + output_no_affine = bn_no_affine(x) + + assert np.allclose(output_no_affine.asnumpy(), x_normalized, atol=1e-5), \ + "Non-affine normalization calculation incorrect" + + def test_batchnorm_training_inference_modes(self): + """Test difference between training and inference modes""" + bn = BatchNorm('2x0e', momentum=0.1, instance=False, affine=False) + x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) + + # Training mode - should update running statistics + 
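+        # The expected values below assume the running-stat convention used here:
+        #   running_stat = (1 - momentum) * running_stat + momentum * batch_stat
+        # with running_mean initialised to 0 and running_var initialised to 1.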
bn.training = True + output_train = bn(x) + + # Verify running statistics update follows momentum formula + x_np = x.asnumpy() + batch_mean = np.mean(x_np, axis=0) + batch_var = np.var(x_np, axis=0, ddof=0) + expected_running_mean = 0.9 * 0.0 + 0.1 * batch_mean # initial mean is 0 + expected_running_var = 0.9 * 1.0 + 0.1 * batch_var # initial var is 1 + + assert np.allclose(bn.running_mean.asnumpy(), expected_running_mean, atol=1e-6), \ + "Running mean update calculation incorrect" + assert np.allclose(bn.running_var.asnumpy(), expected_running_var, atol=1e-6), \ + "Running var update calculation incorrect" + + # Inference mode - should not update running statistics + running_mean_before = bn.running_mean.asnumpy().copy() + running_var_before = bn.running_var.asnumpy().copy() + + bn.training = False + output_inference = bn(x) + + assert np.allclose(bn.running_mean.asnumpy(), running_mean_before), \ + "Running mean should not change in inference mode" + assert np.allclose(bn.running_var.asnumpy(), running_var_before), \ + "Running var should not change in inference mode" + assert not np.any(np.isnan(output_train.asnumpy())) + assert not np.any(np.isnan(output_inference.asnumpy())) + + def test_batchnorm_invalid_parameters(self): + """Test error handling for invalid parameters""" + # Test invalid normalization + with pytest.raises(ValueError, match="Invalid normalization option"): + bn = BatchNorm('2x0e', normalization='invalid') + x = Tensor(np.random.randn(4, 2), dtype=float32) + bn(x) + + # Test invalid reduce + with pytest.raises(ValueError, match="Invalid reduce option"): + bn = BatchNorm('2x0e', reduce='invalid') + x = Tensor(np.random.randn(4, 2), dtype=float32) + bn(x) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/e3nn/nn/test_fc.py b/tests/e3nn/nn/test_fc.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3d9780b8b455871f85c55ce6a9a7d182887b5f --- /dev/null +++ b/tests/e3nn/nn/test_fc.py @@ -0,0 +1,100 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Test cases for FullyConnectedNet""" +import pytest +import numpy as np +from mindspore import Tensor, ops +from mindscience.e3nn.nn.fc import FullyConnectedNet + + +class TestFullyConnectedNet: + """Test cases for FullyConnectedNet""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_fc_basic_creation(self): + """Test basic creation and parameter initialization""" + h_list = [4, 10, 6] + fc = FullyConnectedNet(h_list) + + assert fc.h_list == h_list + assert len(fc.layer_list) == 2 + assert fc.layer_list[0].h_in == 4 and fc.layer_list[0].h_out == 10 + assert fc.layer_list[1].h_in == 10 and fc.layer_list[1].h_out == 6 + assert fc.weight_numel == 4*10 + 10*6 + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_fc_forward_computation(self): + """Test forward propagation computation correctness""" + h_list = [3, 4, 2] + fc = FullyConnectedNet(h_list, act=None, out_act=False) + + x = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + + # Set fixed weights for verification + fc.layer_list[0].weight.set_data(Tensor(np.array([ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2] + ], dtype=np.float32))) + + fc.layer_list[1].weight.set_data(Tensor(np.array([ + [0.1, 0.2], + [0.3, 0.4], + [0.5, 0.6], + [0.7, 0.8] + ], dtype=np.float32))) + + output = fc(x) + + # Manual calculation verification + w1_norm = fc.layer_list[0].weight.asnumpy() / np.sqrt(3) + hidden = np.dot(x.asnumpy(), w1_norm) + w2_norm = fc.layer_list[1].weight.asnumpy() / np.sqrt(4) + expected_output = np.dot(hidden, w2_norm) + + assert output.shape == (2,) + assert np.allclose(output.asnumpy(), expected_output, atol=1e-6) + + def test_fc_activation_function(self): + """Test activation function""" + h_list = [2, 3, 2] + fc_with_act = FullyConnectedNet(h_list, act=ops.tanh, out_act=True) + fc_without_act = FullyConnectedNet(h_list, act=ops.tanh, out_act=False) + + x = Tensor(np.array([1.0, -1.0], dtype=np.float32)) + output_with_act = fc_with_act(x) + output_without_act = fc_without_act(x) + + assert output_with_act.shape == (2,) + assert output_without_act.shape == (2,) + assert not np.allclose(output_with_act.asnumpy(), output_without_act.asnumpy()) + + def test_fc_error_handling(self): + """Test error handling""" + # Test invalid h_list + with pytest.raises(TypeError): + FullyConnectedNet([3.5, 4, 2]) + + # Test minimum valid case + fc_minimal = FullyConnectedNet([2, 1]) + assert len(fc_minimal.layer_list) == 1 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/e3nn/nn/test_gate.py b/tests/e3nn/nn/test_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..89ef2ca22a7c1205188908ec6da0955e06c6d36c --- /dev/null +++ b/tests/e3nn/nn/test_gate.py @@ -0,0 +1,53 @@ +"""Test Gate module""" +import pytest +from mindspore import Tensor, ops, float32 +import numpy as np +from mindscience.e3nn.nn import Gate + + +class TestGate: + """Test cases for Gate module""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_gate_creation(self): + """Test Gate creation and basic properties""" + gate = Gate('2x0e', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') + assert isinstance(gate, Gate) + assert gate.irreps_in.dim > 0 + assert gate.irreps_out.dim > 0 + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + 
@pytest.mark.env_onecard + def test_gate_forward(self): + """Test forward propagation""" + gate = Gate('1x0e', [ops.tanh], '2x0e', [ops.sigmoid, ops.abs], '2x1o') + x = Tensor(np.random.randn(3, gate.irreps_in.dim), dtype=float32) + output = gate(x) + + assert output.shape == (3, gate.irreps_out.dim) + assert not np.isnan(output.asnumpy()).any() + + def test_gate_activations(self): + """Test different activation functions""" + gate1 = Gate('1x0e', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') + gate2 = Gate('1x0e', [ops.relu], '1x0e', [ops.abs], '1x1o') + + x = Tensor(np.random.randn(2, gate1.irreps_in.dim), dtype=float32) + output1, output2 = gate1(x), gate2(x) + + assert output1.shape == output2.shape + assert not np.allclose(output1.asnumpy(), output2.asnumpy(), atol=1e-6) + + def test_gate_errors(self): + """Test error handling""" + with pytest.raises(ValueError, match="Scalars must be scalars"): + Gate('1x1o', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') + + with pytest.raises(ValueError, match="Gate scalars must be scalars"): + Gate('1x0e', [ops.tanh], '1x1o', [ops.sigmoid], '1x1o') + + with pytest.raises(ValueError, match="different number"): + Gate('1x0e', [ops.tanh], '2x0e', [ops.sigmoid, ops.abs], '1x1o') diff --git a/tests/e3nn/nn/test_normact.py b/tests/e3nn/nn/test_normact.py new file mode 100644 index 0000000000000000000000000000000000000000..24a34c3f9c67b50812d5acf156e21ffd903abc1a --- /dev/null +++ b/tests/e3nn/nn/test_normact.py @@ -0,0 +1,50 @@ +"""Test cases for NormActivation module""" +import pytest +from mindspore import Tensor, ops, float32 +import numpy as np +from mindscience.e3nn.nn import NormActivation + + +class TestNormActivation: + """Test cases for NormActivation class""" + + def test_creation_and_forward(self): + """Test NormActivation creation and forward pass""" + normact = NormActivation('2x1e', ops.sigmoid) + assert normact.irreps_in.dim > 0 + assert normact.irreps_out.dim == normact.irreps_in.dim + assert normact.normalize is True + assert normact.epsilon == 1e-8 + + x = Tensor(np.random.randn(3, normact.irreps_in.dim), dtype=float32) + output = normact(x) + assert output.shape == x.shape + assert not np.isnan(output.asnumpy()).any() + + def test_normalize_and_epsilon(self): + """Test normalize parameter and epsilon configuration""" + normact_norm = NormActivation('1x1o', ops.sigmoid, normalize=True) + normact_no_norm = NormActivation('1x1o', ops.sigmoid, normalize=False) + normact_eps = NormActivation('1x1o', ops.sigmoid, epsilon=1e-6) + + assert normact_norm.normalize and normact_norm.epsilon == 1e-8 + assert not normact_no_norm.normalize and normact_no_norm.epsilon is None + assert normact_eps.epsilon == 1e-6 and normact_eps.epsilon * normact_eps.epsilon == 1e-12 + + def test_activations_and_bias(self): + """Test different activation functions and bias parameter""" + normact1 = NormActivation('1x1o', ops.sigmoid, bias=True) + normact2 = NormActivation('1x1o', ops.tanh, bias=False) + + x = Tensor(np.random.randn(2, 3), dtype=float32) + output1, output2 = normact1(x), normact2(x) + + assert output1.shape == output2.shape + assert normact1.bias is not None and normact2.bias is None + + def test_errors(self): + """Test error handling for invalid parameter combinations""" + with pytest.raises(ValueError, match="epsilon.*normalize = False.*don't make sense"): + NormActivation('1x1o', ops.sigmoid, normalize=False, epsilon=1e-6) + with pytest.raises(ValueError, match="epsilon.*invalid.*strictly positive"): + NormActivation('1x1o', ops.sigmoid, epsilon=-1e-6) 
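A minimal NumPy sketch of the norm-activation pattern exercised above, added for illustration; the helper below is hypothetical and assumes the normalize=True path with a single vector irrep (the scalar nonlinearity acts on the norm while the direction is preserved):

import numpy as np

def norm_act_sketch(x, act=None, eps=1e-8):
    """Apply a scalar nonlinearity to the L2 norm of `x`, keeping its direction."""
    if act is None:
        act = lambda n: 1.0 / (1.0 + np.exp(-n))  # sigmoid, as in the tests above
    n = np.sqrt(np.sum(x * x) + eps)  # norm of the single irrep component
    return act(n) * x / n             # rescale the unit vector by act(norm)

# The output keeps the direction of x; only its length changes.
print(norm_act_sketch(np.array([3.0, 0.0, 4.0])))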
diff --git a/tests/e3nn/nn/test_one_hot.py b/tests/e3nn/nn/test_one_hot.py new file mode 100644 index 0000000000000000000000000000000000000000..81dee981e0dac07c32945f75a28d26429d4a139e --- /dev/null +++ b/tests/e3nn/nn/test_one_hot.py @@ -0,0 +1,151 @@ +"""Test cases for one_hot module""" +import pytest +import numpy as np + +from mindspore import Tensor, ops, float32, int32 +from mindscience.e3nn.nn.one_hot import OneHot, SoftOneHotLinspace, soft_one_hot_linspace, soft_unit_step + + +class TestSoftUnitStep: + """Test soft_unit_step function""" + + def test_soft_unit_step_basic(self): + """Test soft_unit_step with basic functionality""" + # Test positive values + x_pos = Tensor([1.0, 2.0], dtype=float32) + result_pos = soft_unit_step(x_pos) + expected_pos = ops.exp(-1.0 / x_pos) + assert np.allclose(result_pos.asnumpy(), expected_pos.asnumpy(), atol=1e-6) + + # Test negative values (should be zero due to relu) + x_neg = Tensor([-1.0, -2.0], dtype=float32) + result_neg = soft_unit_step(x_neg) + expected_neg = Tensor([0.0, 0.0], dtype=float32) + assert np.allclose(result_neg.asnumpy(), expected_neg.asnumpy(), atol=1e-6) + + # Test zero (may be NaN or 0 due to division by zero) + x_zero = Tensor([0.0], dtype=float32) + result_zero = soft_unit_step(x_zero) + result_np = result_zero.asnumpy() + assert result_np[0] == 0.0 or np.isnan(result_np[0]) + + +class TestOneHot: + """Test OneHot class""" + + def test_onehot_basic(self): + """Test OneHot basic functionality""" + num_types = 4 + onehot = OneHot(num_types) + + # Test creation + assert onehot.num_types == num_types + assert str(onehot.irreps_output) == "4x0e" + + # Test single input + atom_type = Tensor([2], dtype=int32) + result = onehot(atom_type) + expected = Tensor([[0., 0., 1., 0.]], dtype=float32) + assert np.allclose(result.asnumpy(), expected.asnumpy()) + assert result.shape == (1, 4) + + # Test batch input + atom_types = Tensor([0, 1, 2], dtype=int32) + result_batch = onehot(atom_types) + expected_batch = Tensor([ + [1., 0., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 1., 0.] 
+ ], dtype=float32) + assert np.allclose(result_batch.asnumpy(), expected_batch.asnumpy()) + assert result_batch.shape == (3, 4) + + +class TestSoftOneHotLinspace: + """Test SoftOneHotLinspace class""" + + def test_soft_onehot_basic(self): + """Test SoftOneHotLinspace basic functionality""" + start, end, number = 0.0, 2.0, 4 + soft_onehot = SoftOneHotLinspace(start, end, number) + + # Test creation + assert soft_onehot.start.asnumpy() == start + assert soft_onehot.end.asnumpy() == end + assert soft_onehot.number == number + + # Test forward pass + x = Tensor([1.0], dtype=float32) + result = soft_onehot(x) + assert result.shape == (1, 4) + + # Test batch input + x_batch = Tensor([[0.5, 1.0], [1.5, 2.0]], dtype=float32) + result_batch = soft_onehot(x_batch) + assert result_batch.shape == (2, 2, 4) + + def test_soft_onehot_different_basis(self): + """Test SoftOneHotLinspace with different basis functions""" + start, end, number = 0.0, 2.0, 3 + x = Tensor([1.0], dtype=float32) + + for basis in ['gaussian', 'cosine', 'smooth_finite']: + soft_onehot = SoftOneHotLinspace(start, end, number, basis=basis) + result = soft_onehot(x) + assert result.shape == (1, 3) + # Some basis functions may produce NaN at boundaries, which is expected + + def test_soft_onehot_cutoff(self): + """Test SoftOneHotLinspace cutoff behavior""" + start, end, number = 0.0, 2.0, 3 + + # Test with and without cutoff + soft_onehot_cutoff = SoftOneHotLinspace(start, end, number, cutoff=True) + soft_onehot_no_cutoff = SoftOneHotLinspace(start, end, number, cutoff=False) + + x = Tensor([3.0], dtype=float32) # Outside domain + result_cutoff = soft_onehot_cutoff(x) + result_no_cutoff = soft_onehot_no_cutoff(x) + + assert result_cutoff.shape == (1, 3) + assert result_no_cutoff.shape == (1, 3) + + +class TestSoftOneHotLinspaceFunction: + """Test soft_one_hot_linspace function""" + + def test_function_basic(self): + """Test soft_one_hot_linspace function interface""" + x = Tensor([1.0, 1.5, 2.0], dtype=float32) + start, end, number = 0.0, 3.0, 4 + + result = soft_one_hot_linspace(x, start, end, number) + assert result.shape == (3, 4) + + # Test with different basis + result_gaussian = soft_one_hot_linspace(x, start, end, number, basis='gaussian') + assert result_gaussian.shape == (3, 4) + + +class TestEdgeCases: + """Test edge cases and error handling""" + + def test_edge_cases(self): + """Test various edge cases""" + # OneHot with single type + onehot = OneHot(1) + atom_type = Tensor([0], dtype=int32) + result = onehot(atom_type) + assert result.shape == (1, 1) + assert np.allclose(result.asnumpy(), Tensor([[1.0]], dtype=float32).asnumpy()) + + # SoftOneHotLinspace with small number + soft_onehot = SoftOneHotLinspace(0.0, 1.0, 2) + x = Tensor([0.5], dtype=float32) + result = soft_onehot(x) + assert result.shape == (1, 2) + + # Invalid basis should raise error + soft_onehot_invalid = SoftOneHotLinspace(0.0, 1.0, 3, basis='invalid') + with pytest.raises(ValueError, match="Unsupported basis"): + soft_onehot_invalid(x) diff --git a/tests/e3nn/nn/test_scatter.py b/tests/e3nn/nn/test_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..5c21243ce40227c163796d53cff6b219cc06eba4 --- /dev/null +++ b/tests/e3nn/nn/test_scatter.py @@ -0,0 +1,64 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test scatter""" +import numpy as np +import pytest +from mindspore import Tensor, float32, int32 +from mindscience.e3nn.nn import Scatter + + +class TestScatter: + """Test Scatter class core functionality""" + + def test_scatter_add(self): + """Test scatter add operation""" + scatter = Scatter(mode='add') + + src = Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=float32) + index = Tensor([0, 1], dtype=int32) + + result = scatter(src, index, dim_size=2) + expected = Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=float32) + + assert np.allclose(result.asnumpy(), expected.asnumpy()) + + def test_scatter_max(self): + """Test scatter max operation""" + scatter = Scatter(mode='max') + + src = Tensor([[1.0, 5.0], [3.0, 2.0], [2.0, 4.0]], dtype=float32) + index = Tensor([0, 1, 0], dtype=int32) + + result = scatter(src, index, dim_size=2) + expected = Tensor([[2.0, 5.0], [3.0, 2.0]], dtype=float32) + + assert np.allclose(result.asnumpy(), expected.asnumpy()) + + def test_scatter_with_out_parameter(self): + """Test scatter with out parameter for proper initialization""" + scatter = Scatter(mode='mul') + + src = Tensor([[2.0, 3.0], [4.0, 5.0]], dtype=float32) + index = Tensor([0, 1], dtype=int32) + out = Tensor([[1.0, 1.0], [1.0, 1.0]], dtype=float32) + + result = scatter(src, index, out=out) + expected = Tensor([[2.0, 3.0], [4.0, 5.0]], dtype=float32) + + assert np.allclose(result.asnumpy(), expected.asnumpy()) + + def test_scatter_invalid_mode(self): + """Test scatter with invalid mode""" + with pytest.raises(ValueError, match="Unexpected scatter mode"): + Scatter(mode='invalid') diff --git a/tests/e3nn/o3/test_irreps.py b/tests/e3nn/o3/test_irreps.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe93092c2b1c4c31931b9642ea2b88be50325df --- /dev/null +++ b/tests/e3nn/o3/test_irreps.py @@ -0,0 +1,253 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Test cases for irreps module""" + +import pytest +from mindspore import ops +from mindscience.e3nn.o3 import Irrep, Irreps + + +class TestIrrep: + """Test cases for Irrep class""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_irrep_creation_and_properties(self): + """Test Irrep creation, properties and basic operations""" + # Test creation with l and p parameters + irrep1 = Irrep(0, 1) + assert irrep1.l == 0 + assert irrep1.p == 1 + assert str(irrep1) == "0e" + assert irrep1.dim == 1 + assert irrep1.is_scalar() is True + + irrep2 = Irrep(1, -1) + assert irrep2.l == 1 + assert irrep2.p == -1 + assert str(irrep2) == "1o" + assert irrep2.dim == 3 + assert irrep2.is_scalar() is False + + # Test creation with string notation + irrep3 = Irrep("2e") + assert irrep3.l == 2 + assert irrep3.p == 1 + assert irrep3.dim == 5 + + irrep4 = Irrep("3y") + assert irrep4.l == 3 + assert irrep4.p == -1 # (-1)^3 = -1 + + # Test comparison operations + assert irrep1 == Irrep(0, 1) + assert irrep1 != irrep2 + assert irrep1 < irrep2 # Compare by l first, then p + + def test_irrep_multiplication_and_arithmetic(self): + """Test Irrep multiplication and arithmetic operations""" + irrep1 = Irrep(1, 1) + irrep2 = Irrep(1, 1) + + # Test tensor product + products = list(irrep1 * irrep2) + expected = [Irrep(0, 1), Irrep(1, 1), Irrep(2, 1)] + assert products == expected + + # Test with different parities + irrep3 = Irrep(1, -1) + products2 = list(irrep1 * irrep3) + expected2 = [Irrep(0, -1), Irrep(1, -1), Irrep(2, -1)] + assert products2 == expected2 + + # Test arithmetic operations + result = 3 * irrep1 + assert isinstance(result, Irreps) + assert result.data[0].mul == 3 + assert result.data[0].ir == irrep1 + + result_add = irrep1 + irrep3 + assert isinstance(result_add, Irreps) + assert len(result_add) == 2 + + def test_irrep_error_handling_and_wigner(self): + """Test Irrep error handling and Wigner D matrix""" + # Test error handling + with pytest.raises(ValueError): + Irrep(-1, 1) # Negative l + + with pytest.raises(ValueError): + Irrep(1, 2) # Invalid parity + + with pytest.raises(ValueError): + Irrep("invalid") + + # Test Wigner D matrix + irrep = Irrep(1, -1) + rotation_matrix = ops.eye(3) + d_matrix = irrep.wigD_from_matrix(rotation_matrix) + assert d_matrix.shape == (3, 3) + + # Test error for non-tensor input + with pytest.raises(TypeError): + irrep.wigD_from_matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + +class TestIrreps: + """Test cases for Irreps class""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_irreps_creation_and_basic_operations(self): + """Test Irreps creation and basic operations""" + # Test creation from string + irreps1 = Irreps("1x0e+2x1o") + assert len(irreps1) == 2 + assert irreps1.data[0].mul == 1 + assert irreps1.data[0].ir == Irrep(0, 1) + assert str(irreps1) == "1x0e+2x1o" + + # Test creation from list of tuples + irreps2 = Irreps([(1, (0, 1)), (2, (1, -1))]) + assert irreps1 == irreps2 + + # Test creation from single Irrep + irreps3 = Irreps(Irrep(1, 1)) + assert len(irreps3) == 1 + + # Test empty creation + irreps4 = Irreps() + assert not irreps4 # Check if empty + assert irreps4.dim == 0 + + # Test single irrep without multiplicity + irreps_single = Irreps("1o") + assert irreps_single.data[0].mul == 1 + assert irreps_single.data[0].ir == Irrep(1, -1) + + def 
test_irreps_properties_and_slicing(self): + """Test Irreps properties and slicing operations""" + irreps = Irreps("2x0e+3x1o+1x2e") + + # Test dimension + expected_dim = 2 * 1 + 3 * 3 + 1 * 5 # 2 + 9 + 5 = 16 + assert irreps.dim == expected_dim + + # Test slices + assert len(irreps.slice) == 3 + assert irreps.slice[0] == slice(0, 2) + assert irreps.slice[1] == slice(2, 11) + assert irreps.slice[2] == slice(11, 16) + + # Test lmax and num_irreps properties + assert irreps.lmax == 2 + assert irreps.num_irreps == 6 # 2 + 3 + 1 + + # Test contains operation + assert Irrep(0, 1) in irreps + assert Irrep(3, 1) not in irreps + + def test_irreps_arithmetic_and_operations(self): + """Test Irreps arithmetic operations and advanced features""" + irreps1 = Irreps("1x0e+1x1o") + irreps2 = Irreps("2x0e+1x2e") + + # Test addition + result_add = irreps1 + irreps2 + assert len(result_add) == 4 + + # Test multiplication with integer + result_mul = irreps1 * 2 + expected_mul = Irreps("2x0e+2x1o") + assert result_mul == expected_mul + + # Test comparison operations + assert irreps1 == Irreps("1x0e+1x1o") + assert irreps1 != irreps2 + + # Test iteration + for i, (mul, ir) in enumerate(irreps1): + if i == 0: + assert mul == 1 and ir == Irrep(0, 1) + elif i == 1: + assert mul == 1 and ir == Irrep(1, -1) + + def test_irreps_error_handling_and_edge_cases(self): + """Test Irreps error handling and edge cases""" + # Test invalid string format + with pytest.raises(ValueError): + Irreps("invalid_format") + + # Test negative multiplicity + with pytest.raises(ValueError): + Irreps([(-1, (0, 1))]) + + # Test invalid multiplicity type + with pytest.raises(ValueError): + Irreps([(1.5, (0, 1))]) + + # Test empty Irreps lmax property + irreps_empty = Irreps("") + with pytest.raises(ValueError): + _ = irreps_empty.lmax + + # Test zero multiplicity + zero_irreps = Irreps("0x1o+2x0e") + assert len(zero_irreps) == 2 + assert zero_irreps.data[0].mul == 0 + + # Test large irreps + large_irreps = Irreps("100x0e+50x1e") + assert large_irreps.dim == 100 * 1 + 50 * 3 + assert len(large_irreps) == 2 + + +class TestMulIr: + """Test cases for _MulIr class""" + + def test_mulir_comprehensive(self): + """Test _MulIr creation, properties and operations""" + from mindscience.e3nn.o3.irreps import _MulIr + + # Test creation and properties + irrep = Irrep(1, 1) + mulir = _MulIr(3, irrep) + assert mulir.mul == 3 + assert mulir.ir == irrep + assert mulir.dim == 3 * 3 # mul * irrep.dim + assert str(mulir) == "3x1e" + + # Test iteration/deconstruction + mul, ir = mulir + assert mul == 3 + assert ir == irrep + + # Test comparison operations + mulir2 = _MulIr(3, irrep) + mulir3 = _MulIr(2, irrep) + mulir4 = _MulIr(3, Irrep(2, 1)) + + assert mulir == mulir2 + assert mulir != mulir3 + assert mulir < mulir4 # Compare by irrep first + + # Test error handling + with pytest.raises(TypeError): + _MulIr(1.5, irrep) # mul should be int + + with pytest.raises(TypeError): + _MulIr(2, "1e") # ir should be Irrep instance diff --git a/tests/e3nn/o3/test_norm.py b/tests/e3nn/o3/test_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..84118cbdb669090789b96e716aee5a2401e2c3ef --- /dev/null +++ b/tests/e3nn/o3/test_norm.py @@ -0,0 +1,110 @@ +"""Test cases for e3nn.o3.norm module - Streamlined core functionality""" +import pytest +import numpy as np +from mindspore import Tensor, float32 + +from mindscience.e3nn.o3 import Norm, Irreps + + +class TestNorm: + """Streamlined tests for Norm class""" + + def 
test_norm_creation_and_basic_properties(self): + """Test Norm creation with different irreps and basic properties""" + # Test basic creation with string irreps + norm1 = Norm('1x0e') + assert norm1.irreps_in == Irreps('1x0e') + assert norm1.irreps_out == Irreps('1x0e') + assert not norm1.squared + + # Test creation with Irreps object and squared parameter + irreps_in = Irreps('2x1o + 3x0e') + norm2 = Norm(irreps_in, squared=True) + assert norm2.irreps_in == irreps_in.simplify() + assert norm2.irreps_out == Irreps('2x0e + 3x0e').simplify() + assert norm2.squared + + # Test string representation + repr_str = repr(norm2) + assert 'Norm' in repr_str + + def test_norm_forward_pass_comprehensive(self): + """Test forward pass with various irrep types and configurations""" + # Test scalar irrep (0e) + norm_scalar = Norm('2x0e') + scalar_input = Tensor([1.0, -2.0], dtype=float32) + scalar_output = norm_scalar(scalar_input) + np.testing.assert_allclose(scalar_output.asnumpy(), [1.0, 2.0], rtol=1e-5) + + # Test vector irrep (1o) with batch processing + norm_vector = Norm('1x1o') + vector_batch = Tensor([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]], dtype=float32) + vector_output = norm_vector(vector_batch) + expected = np.array([[5.0], [0.0]]) + np.testing.assert_allclose(vector_output.asnumpy(), expected, rtol=1e-5) + + # Test mixed irreps + norm_mixed = Norm('1x0e + 1x1o') + mixed_input = Tensor([2.0, 3.0, 4.0, 0.0], dtype=float32) + mixed_output = norm_mixed(mixed_input) + expected_mixed = np.array([2.0, 5.0]) # scalar norm + vector norm + np.testing.assert_allclose(mixed_output.asnumpy(), expected_mixed, rtol=1e-5) + + def test_norm_squared_and_dtype_consistency(self): + """Test squared parameter and dtype consistency""" + # Test squared vs regular norm + norm_regular = Norm('1x1o', squared=False, dtype=float32) + norm_squared = Norm('1x1o', squared=True, dtype=float32) + + input_vec = Tensor([3.0, 4.0, 0.0], dtype=float32) + output_regular = norm_regular(input_vec) + output_squared = norm_squared(input_vec) + + # Verify squared relationship and dtype consistency + np.testing.assert_allclose(output_regular.asnumpy(), [5.0], rtol=1e-5) + np.testing.assert_allclose(output_squared.asnumpy(), [25.0], rtol=1e-5) + assert output_regular.dtype == float32 + assert output_squared.dtype == float32 + + def test_norm_mathematical_properties_and_edge_cases(self): + """Test mathematical properties and edge cases""" + norm = Norm('1x1o') + + # Test scaling property: ||k*v|| = |k| * ||v|| + vector = Tensor([3.0, 4.0, 0.0], dtype=float32) + scaled_vector = Tensor([6.0, 8.0, 0.0], dtype=float32) + + norm_original = norm(vector) + norm_scaled = norm(scaled_vector) + np.testing.assert_allclose(norm_scaled.asnumpy(), 2.0 * norm_original.asnumpy(), rtol=1e-5) + + # Test with very small values + small_input = Tensor([1e-10, 1e-10, 1e-10], dtype=float32) + small_output = norm(small_input) + expected_small = np.sqrt(3) * 1e-10 + np.testing.assert_allclose(small_output.asnumpy(), [expected_small], rtol=1e-5) + + def test_norm_higher_order_and_mixed_parity(self): + """Test higher order irreps and mixed parity""" + # Test l=2 irrep + norm_l2 = Norm('1x2e') + l2_input = Tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=float32) + l2_output = norm_l2(l2_input) + expected_l2 = np.sqrt(5.0) + np.testing.assert_allclose(l2_output.asnumpy(), [expected_l2], rtol=1e-5) + + # Test mixed parity + norm_mixed_parity = Norm('1x0e + 1x1o + 1x0o') + mixed_parity_input = Tensor([2.0, 1.0, 1.0, 1.0, 3.0], dtype=float32) + mixed_parity_output = 
norm_mixed_parity(mixed_parity_input) + expected_mixed_parity = np.array([2.0, np.sqrt(3.0), 3.0]) + np.testing.assert_allclose(mixed_parity_output.asnumpy(), expected_mixed_parity, rtol=1e-5) + + def test_norm_error_handling(self): + """Test error handling for invalid inputs""" + norm = Norm('1x1o') + + # Test with wrong input dimension + with pytest.raises((ValueError, RuntimeError)): + wrong_dim_input = Tensor([1.0, 2.0], dtype=float32) # Should be 3D for 1x1o + norm(wrong_dim_input) diff --git a/tests/e3nn/o3/test_rotation.py b/tests/e3nn/o3/test_rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..53ed0c4105cb221348e0da45991d82e4fe6b3b5c --- /dev/null +++ b/tests/e3nn/o3/test_rotation.py @@ -0,0 +1,185 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test cases for rotation module.""" + +import math +import pytest +import numpy as np +from mindspore import Tensor, float32 + +from mindscience.e3nn.o3.rotation import ( + identity_angles, rand_angles, compose_angles, + matrix_x, matrix_y, matrix_z, + angles_to_matrix, matrix_to_angles, + angles_to_xyz, xyz_to_angles +) + + +class TestRotation: + """Test class for rotation functions.""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_identity_angles(self): + """Test identity_angles function comprehensively.""" + # Test basic functionality and shapes + alpha, beta, gamma = identity_angles(2, 3) + assert alpha.shape == (2, 3) + assert beta.shape == (2, 3) + assert gamma.shape == (2, 3) + assert np.allclose(alpha.asnumpy(), 0.0) + assert np.allclose(beta.asnumpy(), 0.0) + assert np.allclose(gamma.asnumpy(), 0.0) + + # Test dtype + alpha, beta, gamma = identity_angles(2, dtype=float32) + assert alpha.dtype == float32 + + # Test error handling + with pytest.raises(TypeError): + identity_angles(1.5) # Should be int + + def test_rand_angles(self): + """Test rand_angles function comprehensively.""" + # Test shapes and angle ranges + alpha, beta, gamma = rand_angles(2, 3) + assert alpha.shape == (2, 3) + assert beta.shape == (2, 3) + assert gamma.shape == (2, 3) + assert np.all(alpha.asnumpy() >= 0) and np.all(alpha.asnumpy() <= 2 * math.pi) + assert np.all(beta.asnumpy() >= 0) and np.all(beta.asnumpy() <= math.pi) + assert np.all(gamma.asnumpy() >= 0) and np.all(gamma.asnumpy() <= 2 * math.pi) + + # Test error handling + with pytest.raises(TypeError): + rand_angles(1.5) # Should be int + + def test_rotation_matrices(self): + """Test rotation matrix functions (matrix_x, matrix_y, matrix_z).""" + # Test identity matrices with zero angle + for matrix_func in [matrix_x, matrix_y, matrix_z]: + mat = matrix_func(0.0) + assert np.allclose(mat.asnumpy(), np.eye(3), atol=1e-6) + + # Test specific rotations + mat_x = matrix_x(math.pi / 2) + expected_x = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + assert np.allclose(mat_x.asnumpy(), 
expected_x, atol=1e-6) + + mat_y = matrix_y(math.pi / 2) + expected_y = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]]) + assert np.allclose(mat_y.asnumpy(), expected_y, atol=1e-6) + + mat_z = matrix_z(math.pi / 2) + expected_z = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]]) + assert np.allclose(mat_z.asnumpy(), expected_z, atol=1e-6) + + # Test batch operations + angles = Tensor([0.1, 0.2, 0.3]) + mat = matrix_x(angles) + assert mat.shape == (3, 3, 3) + + def test_rotation_matrices_orthogonal(self): + """Test that rotation matrices are orthogonal.""" + angle = 0.5 + for matrix_func in [matrix_x, matrix_y, matrix_z]: + mat = matrix_func(angle) + # Check orthogonality: R @ R.T = I + identity = np.matmul(mat.asnumpy(), mat.asnumpy().T) + assert np.allclose(identity, np.eye(3), atol=1e-6) + # Check determinant = 1 + assert np.allclose(np.linalg.det(mat.asnumpy()), 1.0, atol=1e-6) + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_angle_matrix_conversion(self): + """Test angles_to_matrix and matrix_to_angles functions.""" + # Test identity conversion + mat = angles_to_matrix(0.0, 0.0, 0.0) + assert np.allclose(mat.asnumpy(), np.eye(3), atol=1e-6) + + # Test roundtrip conversion + alpha_orig = Tensor([0.1, 0.2, 0.3]) + beta_orig = Tensor([0.4, 0.5, 0.6]) + gamma_orig = Tensor([0.7, 0.8, 0.9]) + + mat = angles_to_matrix(alpha_orig, beta_orig, gamma_orig) + assert mat.shape == (3, 3, 3) + alpha_new, beta_new, gamma_new = matrix_to_angles(mat) + assert np.allclose(alpha_orig.asnumpy(), alpha_new.asnumpy(), atol=1e-5) + assert np.allclose(beta_orig.asnumpy(), beta_new.asnumpy(), atol=1e-5) + assert np.allclose(gamma_orig.asnumpy(), gamma_new.asnumpy(), atol=1e-5) + + def test_angles_matrix_roundtrip(self): + """Test roundtrip conversion between angles and matrix.""" + # Test multiple angle sets + test_angles = [ + (0.1, 0.2, 0.3), + (0.4, 0.5, 0.6), + (1.0, 1.5, 2.0), + (math.pi/4, math.pi/3, math.pi/6) + ] + + for alpha, beta, gamma in test_angles: + # Convert angles to matrix and back + mat = angles_to_matrix(alpha, beta, gamma) + alpha_rec, beta_rec, gamma_rec = matrix_to_angles(mat) + + # Check if we get back the same angles (within tolerance) + # Note: Euler angles may have multiple representations + mat_rec = angles_to_matrix(alpha_rec, beta_rec, gamma_rec) + assert np.allclose(mat.asnumpy(), mat_rec.asnumpy(), atol=1e-5) + + def test_matrix_to_angles_error(self): + """Test matrix_to_angles error handling.""" + # Test with non-rotation matrix (determinant != 1) + invalid_matrix = Tensor(np.array([[2, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)) + with pytest.raises(ValueError): + matrix_to_angles(invalid_matrix) + + def test_angle_operations(self): + """Test compose_angles, angles_to_xyz, and xyz_to_angles functions.""" + # Test compose_angles with identity + alpha_comp, beta_comp, gamma_comp = compose_angles(0.0, 0.0, 0.0, 0.1, 0.2, 0.3) + assert np.allclose(alpha_comp.asnumpy(), 0.1, atol=1e-6) + assert np.allclose(beta_comp.asnumpy(), 0.2, atol=1e-6) + assert np.allclose(gamma_comp.asnumpy(), 0.3, atol=1e-6) + + # Test angles_to_xyz and xyz_to_angles roundtrip + xyz = angles_to_xyz(0.0, 0.0) + assert np.allclose(xyz.asnumpy(), [0.0, 1.0, 0.0], atol=1e-6) + alpha, beta = xyz_to_angles(xyz) + assert np.allclose(alpha.asnumpy(), 0.0, atol=1e-6) + assert np.allclose(beta.asnumpy(), 0.0, atol=1e-6) + + def test_batch_and_edge_cases(self): + """Test batch operations and edge cases.""" + # Test batch operations + alphas = 
Tensor(np.array([[0.1, 0.2], [0.3, 0.4]]).astype(np.float32)) + betas = Tensor(np.array([[0.5, 0.6], [0.7, 0.8]]).astype(np.float32)) + gammas = Tensor(np.array([[0.9, 1.0], [1.1, 1.2]]).astype(np.float32)) + matrices = angles_to_matrix(alphas, betas, gammas) + assert matrices.shape == (2, 2, 3, 3) + # Test edge case: small angles + mat = angles_to_matrix(1e-8, 1e-8, 1e-8) + assert np.allclose(mat.asnumpy(), np.eye(3), atol=1e-6) + + # Test edge case: pi angles (should still be valid rotation matrix) + mat = angles_to_matrix(math.pi, math.pi, math.pi) + identity = np.matmul(mat.asnumpy(), mat.asnumpy().T) + assert np.allclose(identity, np.eye(3), atol=1e-5) + assert np.allclose(np.linalg.det(mat.asnumpy()), 1.0, atol=1e-5) diff --git a/tests/e3nn/o3/test_spherical_harmonics.py b/tests/e3nn/o3/test_spherical_harmonics.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a78940fe7825b3d2803f7e54fc8042a5ccc30e --- /dev/null +++ b/tests/e3nn/o3/test_spherical_harmonics.py @@ -0,0 +1,180 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test spherical harmonics module.""" + +import pytest +import numpy as np +from mindspore import Tensor, float32 +from mindscience.e3nn.o3 import spherical_harmonics, SphericalHarmonics + + +class TestSphericalHarmonicsFunction: + """Test spherical_harmonics function.""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_core_functionality(self): + """Test core spherical harmonics functionality including degrees and normalization.""" + x = Tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], float32) + + # Test l=0 (constant function) + result_l0 = spherical_harmonics(0, x) + assert result_l0.shape == (2, 1) + np.testing.assert_allclose(result_l0.asnumpy(), [[0.28209479], [0.28209479]], rtol=1e-5) + + # Test different degrees + result_l1 = spherical_harmonics(1, x[:1]) + assert result_l1.shape == (1, 3) + result_l2 = spherical_harmonics(2, x[:1]) + assert result_l2.shape == (1, 5) + + # Test multiple degrees + result_multi = spherical_harmonics([0, 1, 2], x[:1]) + assert result_multi.shape == (1, 9) # 1 + 3 + 5 + + def test_normalization_and_parameters(self): + """Test normalization methods and normalize parameter.""" + x = Tensor([[1.0, 0.0, 0.0]], float32) + x_unnorm = Tensor([[2.0, 0.0, 0.0]], float32) + + # Test different normalization methods + result_integral = spherical_harmonics(1, x, normalization='integral') + result_component = spherical_harmonics(1, x, normalization='component') + result_norm = spherical_harmonics(1, x, normalization='norm') + + # Results should be different for different normalizations + assert not np.allclose(result_integral.asnumpy(), result_component.asnumpy()) + assert not np.allclose(result_integral.asnumpy(), result_norm.asnumpy()) + + # Test normalize parameter + result_normalized = spherical_harmonics(1, x_unnorm, normalize=True) + 
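# If this port follows the e3nn convention, normalize=True rescales x_unnorm onto the unit sphere + # before evaluating the harmonics, while normalize=False feeds the raw vector into the degree-l + # homogeneous polynomials, so the two results below should differ by a factor of |x|**l (2 for l=1). +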
result_unnormalized = spherical_harmonics(1, x_unnorm, normalize=False) + assert not np.allclose(result_normalized.asnumpy(), result_unnormalized.asnumpy()) + + def test_batch_and_shapes(self): + """Test batch processing and different input shapes.""" + # Multiple vectors + x = Tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], float32) + result = spherical_harmonics(2, x) + assert result.shape == (3, 5) + + # Higher dimensional batch + x_batch = Tensor(np.random.randn(2, 3, 3).astype(np.float32)) + result_batch = spherical_harmonics(1, x_batch) + assert result_batch.shape == (2, 3, 3) + + # 3D input + x_3d = Tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], float32) + result_3d = spherical_harmonics(1, x_3d) + assert result_3d.shape == (1, 2, 3) + + +class TestSphericalHarmonicsClass: + """Test the SphericalHarmonics class.""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_class_initialization_and_forward(self): + """Test class initialization and forward computation.""" + # Test initialization + sh = SphericalHarmonics(2, normalize=True) + # Verify the output dimension instead of accessing protected member + assert sh.irreps_out.dim == 5 + # Verify the irreps_out contains the expected l=2 representation + assert str(sh.irreps_out) == "1x2e" + + # Test forward computation + x = Tensor([[1.0, 0.0, 0.0]], float32) + result = sh(x) + assert result.shape == (1, 5) + + # Compare with function version + result_func = spherical_harmonics(2, x) + np.testing.assert_allclose(result.asnumpy(), result_func.asnumpy(), rtol=1e-5) + + def test_consistency_and_parity(self): + """Test normalization consistency and parity.""" + x = Tensor([[1.0, 0.0, 0.0]], float32) + + # Test normalization consistency + sh_integral = SphericalHarmonics(1, normalize=True, normalization='integral') + sh_component = SphericalHarmonics(1, normalize=True, normalization='component') + result_integral = sh_integral(x) + result_component = sh_component(x) + assert not np.allclose(result_integral.asnumpy(), result_component.asnumpy()) + + # Test parity consistency + sh = SphericalHarmonics(2, normalize=True) + x_pos = Tensor([[1.0, 0.0, 0.0]], float32) + x_neg = Tensor([[-1.0, 0.0, 0.0]], float32) + result_pos = sh(x_pos) + result_neg = sh(x_neg) + # For even l, parity should be preserved + np.testing.assert_allclose(result_pos.asnumpy(), result_neg.asnumpy(), rtol=1e-5) + + +class TestMathematicalProperties: + """Test mathematical properties of spherical harmonics.""" + + def test_mathematical_properties(self): + """Test basic mathematical properties and rotation equivariance.""" + # Test basic properties for l=1 + x = Tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], float32) + result = spherical_harmonics(1, x) + + # Test that results are finite and have correct shape + assert result.shape == (3, 3) + assert np.all(np.isfinite(result.asnumpy())) + + # Test rotation equivariance (simplified) + x_original = Tensor([[1.0, 0.0, 0.0]], float32) + x_rotated = Tensor([[0.0, 1.0, 0.0]], float32) # 90° rotation around z + sh_original = spherical_harmonics(1, x_original) + sh_rotated = spherical_harmonics(1, x_rotated) + # Results should be different for different orientations + assert sh_original.shape == sh_rotated.shape + assert not np.allclose(sh_original.asnumpy(), sh_rotated.asnumpy()) + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_edge_cases(self): + """Test zero vectors, high degrees, and error conditions.""" + # Test 
zero vector + x_zero = Tensor([[0.0, 0.0, 0.0]], float32) + result_zero = spherical_harmonics(1, x_zero) + assert result_zero.shape == (1, 3) + + # Test high degree + x = Tensor([[1.0, 0.0, 0.0]], float32) + result_high = spherical_harmonics(5, x) + assert result_high.shape == (1, 11) # 2*5+1 + + # Test invalid degree (should raise error) + with pytest.raises((ValueError, TypeError)): + spherical_harmonics(-1, x) + + # Test invalid normalization + with pytest.raises((ValueError, TypeError)): + spherical_harmonics(1, x, normalization='invalid') diff --git a/tests/e3nn/o3/test_sub.py b/tests/e3nn/o3/test_sub.py new file mode 100644 index 0000000000000000000000000000000000000000..fd8d1aabb6532370a2a403cbc65c231d694f344e --- /dev/null +++ b/tests/e3nn/o3/test_sub.py @@ -0,0 +1,142 @@ +# Copyright 2021-2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test cases for o3.sub module. + +This module contains comprehensive tests for all classes and functions in the o3.sub module, +including tensor product operations, linear operations, and utility functions.
+""" + +import pytest +import numpy as np +import mindspore as ms +from mindspore import Tensor + +from mindscience.e3nn.o3.sub import ( + FullyConnectedTensorProduct, + FullTensorProduct, + ElementwiseTensorProduct, + Linear, + LinearBias, + TensorSquare, + prod, + _prod, + _sum_tensors_withbias, + Instruction +) + + +class TestTensorProductClasses: + """Test tensor product classes functionality.""" + + def test_tensor_product_operations(self): + """Test core tensor product operations.""" + # Test FullyConnectedTensorProduct + tp_fc = FullyConnectedTensorProduct('1x1o', '1x0e', '1x1o') + x1 = Tensor(np.random.randn(2, 3), ms.float32) + x2 = Tensor(np.random.randn(2, 1), ms.float32) + output_fc = tp_fc(x1, x2) + assert output_fc.shape == (2, 3) + + # Test FullTensorProduct + tp_full = FullTensorProduct('1x1o', '1x0e') + output_full = tp_full(x1, x2) + assert output_full.ndim == 2 + + # Test ElementwiseTensorProduct + tp_elem = ElementwiseTensorProduct('1x1o', '1x1o') + output_elem = tp_elem(x1, x1) + assert output_elem.ndim == 2 + + def test_linear_operations(self): + """Test linear operations with and without bias.""" + # Test Linear + linear = Linear('1x1o+1x0e', '1x1o') + x = Tensor(np.random.randn(2, 4), ms.float32) + output = linear(x) + assert output.shape == (2, 3) + + # Test LinearBias + linear_bias = LinearBias('1x1o+1x0e', '1x1o+1x0e', has_bias=True) + output_bias = linear_bias(x) + assert output_bias.shape == (2, 4) + + def test_tensor_square(self): + """Test TensorSquare operation.""" + ts = TensorSquare('1x1o', irreps_out='1x0e+1x2e') + x = Tensor(np.random.randn(2, 3), ms.float32) + output = ts(x) + assert output.shape == (2, 6) # 1x0e+1x2e has dim 6 + + +class TestUtilityFunctions: + """Test utility functions.""" + + def test_prod_functions(self): + """Test product computation functions.""" + # Test prod function + assert prod([2, 3, 4]) == 24 + assert prod([]) == 1 + + # Test _prod function + assert _prod((2, 3, 4)) == 24 + assert _prod(()) == 1 + + def test_tensor_utilities(self): + """Test tensor utility functions.""" + # Test _sum_tensors_withbias + t1 = Tensor(np.array([1, 2, 3]), ms.float32) + t2 = Tensor(np.array([4, 5, 6]), ms.float32) + + result = _sum_tensors_withbias([t1, t2], (3,), ms.float32) + expected = np.array([5, 7, 9]) + assert np.allclose(result.asnumpy(), expected) + + # Test Instruction NamedTuple + instr = Instruction(i_in=0, i_out=1, path_shape=(2, 3), path_weight=1.5) + assert instr.i_in == 0 and instr.i_out == 1 + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_error_handling(self): + """Test error handling for invalid inputs.""" + # Test invalid irreps + with pytest.raises((ValueError, TypeError)): + FullyConnectedTensorProduct('invalid', '1x0e', '1x0e') + + # Test mismatched dimensions + tp = FullyConnectedTensorProduct('1x0e', '1x0e', '1x0e') + x1 = Tensor(np.random.randn(2, 5), ms.float32) # Wrong dimension + x2 = Tensor(np.random.randn(2, 1), ms.float32) + + with pytest.raises(ValueError): + tp(x1, x2) + + def test_scalar_operations(self): + """Test operations with scalar irreps.""" + linear = Linear('1x0e', '1x0e') + x = Tensor(np.random.randn(2, 1), ms.float32) + output = linear(x) + assert output.shape == (2, 1) diff --git a/tests/e3nn/o3/test_tensor_product.py b/tests/e3nn/o3/test_tensor_product.py new file mode 100644 index 0000000000000000000000000000000000000000..e5793ee51d80f97add7c2b594c69312385250b5d --- /dev/null +++ b/tests/e3nn/o3/test_tensor_product.py @@ -0,0 +1,122 @@ +# Copyright 2022 Huawei 
Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test cases for tensor_product module.""" + +import pytest +import numpy as np +from mindspore import Tensor, float32 + +from mindscience.e3nn.o3.tensor_product import TensorProduct +from mindscience.e3nn.o3.sub import ( + FullTensorProduct, FullyConnectedTensorProduct, + ElementwiseTensorProduct, TensorSquare, Linear +) +from mindscience.e3nn.o3.irreps import Irreps + + +class TestTensorProduct: + """Test class for TensorProduct and related classes.""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_tensor_product_basic(self): + """Test basic TensorProduct functionality.""" + # Test standard tensor product + tp = TensorProduct('2x1o+1x0e', '1x1o+1x0e') + assert tp.irreps_in1.dim == 7 # 2*3 + 1*1 = 7 + assert tp.irreps_in2.dim == 4 # 1*3 + 1*1 = 4 + + # Test with input tensors + x1 = Tensor(np.random.randn(2, tp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(2, tp.irreps_in2.dim), dtype=float32) + output = tp(x1, x2) + assert output.shape == (2, tp.irreps_out.dim) + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_full_tensor_product(self): + """Test FullTensorProduct functionality.""" + # Test full tensor product + ftp = FullTensorProduct('1x1o+1x0e', '1x1o+1x0e') + x1 = Tensor(np.random.randn(2, ftp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(2, ftp.irreps_in2.dim), dtype=float32) + output = ftp(x1, x2) + assert output.shape == (2, ftp.irreps_out.dim) + + def test_fully_connected_tensor_product(self): + """Test FullyConnectedTensorProduct functionality.""" + # Test fully connected tensor product + fctp = FullyConnectedTensorProduct('1x1o', '1x1o', '1x2e+1x0e') + x1 = Tensor(np.random.randn(2, fctp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(2, fctp.irreps_in2.dim), dtype=float32) + output = fctp(x1, x2) + assert output.shape == (2, fctp.irreps_out.dim) + assert fctp.weight_numel > 0 # Should have learnable weights + + def test_elementwise_tensor_product(self): + """Test ElementwiseTensorProduct functionality.""" + # Test elementwise tensor product + etp = ElementwiseTensorProduct('2x1o+1x0e', '2x1o+1x0e') + x1 = Tensor(np.random.randn(2, etp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(2, etp.irreps_in2.dim), dtype=float32) + output = etp(x1, x2) + assert output.shape == (2, etp.irreps_out.dim) + + def test_tensor_square(self): + """Test TensorSquare functionality.""" + # Test tensor square without output specification + ts = TensorSquare('1x1o+1x0e') + x = Tensor(np.random.randn(2, ts.irreps_in1.dim), dtype=float32) + output = ts(x) + assert output.shape == (2, ts.irreps_out.dim) + + # Test tensor square with output specification + ts_out = TensorSquare('1x1o', irreps_out='1x2e+1x0e') + x = Tensor(np.random.randn(2, ts_out.irreps_in1.dim), dtype=float32) 
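+ # Dimension bookkeeping for the checks below: an irrep of degree l spans 2*l + 1 components, + # so the requested irreps_out '1x2e+1x0e' should have dim (2*2+1) + (2*0+1) = 6.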
+ output = ts_out(x) + assert output.shape == (2, ts_out.irreps_out.dim) + assert ts_out.weight_numel > 0 # Should have learnable weights + + def test_linear_operation(self): + """Test Linear operation functionality.""" + # Test linear operation + linear = Linear('1x1o+1x0e', '2x1o+1x0e') + x = Tensor(np.random.randn(2, linear.irreps_in1.dim), dtype=float32) + output = linear(x) + assert output.shape == (2, linear.irreps_out.dim) + assert linear.weight_numel > 0 # Should have learnable weights + + def test_tensor_product_properties(self): + """Test tensor product properties and edge cases.""" + # Test properties + tp = TensorProduct('1x1o', '1x1o', '1x2e+1x0e', instructions='connect') + assert isinstance(tp.irreps_in1, Irreps) + assert isinstance(tp.irreps_in2, Irreps) + assert isinstance(tp.irreps_out, Irreps) + assert isinstance(tp.instructions, list) + assert tp.weight_numel >= 0 + + # Test string representation + repr_str = repr(tp) + assert 'TensorProduct' in repr_str + assert 'connect' in repr_str + + # Test with single batch + x1 = Tensor(np.random.randn(tp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(tp.irreps_in2.dim), dtype=float32) + output = tp(x1, x2) + assert output.shape == (tp.irreps_out.dim,) diff --git a/tests/e3nn/o3/test_wigner.py b/tests/e3nn/o3/test_wigner.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d0b5f4576a1c6b2816a0366fa152f07cbf61a7 --- /dev/null +++ b/tests/e3nn/o3/test_wigner.py @@ -0,0 +1,119 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Test cases for o3.wigner module.""" + +import pytest +import numpy as np +from mindspore import float32, float64, complex64, complex128 + +from mindscience.e3nn.o3.wigner import ( + change_basis_real_to_complex, + su2_generators, + so3_generators, + wigner_D, + wigner_3j +) + +class TestWigner: + """Test wigner module functions.""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_change_basis_real_to_complex(self): + """Test change_basis_real_to_complex function.""" + # Test basic functionality + result = change_basis_real_to_complex(1) + assert result.shape == (3, 3) + assert result.dtype == complex64 + + # Test dtype conversion + result = change_basis_real_to_complex(1, dtype=float64) + assert result.dtype == complex128 + + # Test unitarity property + q_matrix = change_basis_real_to_complex(1) + q_np = q_matrix.asnumpy() + identity = np.eye(3) + np.testing.assert_allclose(q_np @ q_np.conj().T, identity, atol=1e-6) + + def test_su2_generators(self): + """Test su2_generators function.""" + # Test basic functionality + result = su2_generators(1) + assert result.shape == (3, 3, 3) + assert result.dtype == complex64 + + # Test dtype + result = su2_generators(1, dtype=complex128) + assert result.dtype == complex128 + + # Test invalid input + with pytest.raises(TypeError): + su2_generators(1.5) + + def test_so3_generators(self): + """Test so3_generators function.""" + # Test basic functionality + result = so3_generators(1) + assert result.shape == (3, 3, 3) + assert result.dtype == float32 + + # Test dtype + result = so3_generators(1, dtype=float64) + assert result.dtype == float64 + + # Test invalid input + with pytest.raises(TypeError): + so3_generators(1.5) + + def test_wigner_d(self): + """Test wigner_D function.""" + # Test identity rotation + result = wigner_D(1, 0, 0, 0) + assert result.shape == (3, 3) + expected = np.eye(3) + np.testing.assert_allclose(result.asnumpy(), expected, atol=1e-6) + + # Test orthogonality property + d_matrix = wigner_D(1, 0.5, 0.3, 0.7) + identity = np.eye(3) + np.testing.assert_allclose((d_matrix @ d_matrix.T).asnumpy(), identity, atol=1e-5) + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_wigner_3j(self): + """Test wigner_3j function.""" + # Test basic functionality + result = wigner_3j(1, 1, 1) + assert result.shape == (3, 3, 3) + assert result.dtype == float32 + + # Test dtype + result = wigner_3j(1, 1, 0, dtype=float64) + assert result.dtype == float64 + + # Test normalization property + coeffs = wigner_3j(1, 1, 1) + norm_squared = np.sum(coeffs.asnumpy() ** 2) + np.testing.assert_allclose(norm_squared, 1.0, atol=1e-6) + + # Test invalid combinations + with pytest.raises(ValueError): + wigner_3j(1, 1, 3) + + with pytest.raises(TypeError): + wigner_3j(1.5, 1, 1) diff --git a/tests/e3nn/utils/test_utils.py b/tests/e3nn/utils/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..94825bdd0184e8e281d8dd20aebe950f8729e568 --- /dev/null +++ b/tests/e3nn/utils/test_utils.py @@ -0,0 +1,189 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test cases for e3nn.utils module. + +This module contains comprehensive test cases for all utility functions +in the e3nn.utils package, including tensor operations, linear algebra, +tensor contractions, radius computations, and initialization utilities. +""" + +import pytest +import numpy as np +import mindspore as ms +from mindspore import Tensor +from mindspore.common.initializer import TruncatedNormal + +from mindscience.e3nn.utils.func import broadcast_args, _ndexpm, narrow +from mindscience.e3nn.utils.linalg import _direct_sum +from mindscience.e3nn.utils.ncon import Ncon +from mindscience.e3nn.utils.radius import radius, radius_graph, radius_full +from mindscience.e3nn.utils.initializer import Uniform, renormal_initializer +from mindscience.e3nn.utils.perm import _from_int, _to_int, _inverse, _compose, _group, _germinate + +class TestFuncModule: + """Test cases for func.py module.""" + + def test_broadcast_and_operations(self): + """Test broadcasting, matrix exponential, and tensor slicing.""" + # Test broadcasting + a = Tensor([1.0, 2.0]) + b = Tensor([[3.0], [4.0]]) + result = broadcast_args(a, b) + assert len(result) == 2 and result[0].shape == (2, 2) + + # Test matrix exponential + mat = Tensor([[0.0, 1.0], [-1.0, 0.0]], dtype=ms.float32) + exp_result = _ndexpm(mat) + assert exp_result.shape == (2, 2) + + # Test tensor slicing + x = Tensor(np.arange(24).reshape(2, 3, 4), dtype=ms.float32) + sliced = narrow(x, axis=0, start=0, length=1) + assert sliced.shape == (1, 3, 4) + +class TestLinalgModule: + """Test cases for linalg.py module.""" + + def test_direct_sum(self): + """Test direct sum of matrices.""" + a = Tensor([[1.0, 2.0], [3.0, 4.0]]) + b = Tensor([[5.0]]) + result = _direct_sum(a, b) + assert result.shape == (3, 3) + + # Test with batch dimensions + batch_a = Tensor(np.random.randn(2, 3, 3).astype(np.float32)) + batch_b = Tensor(np.random.randn(2, 2, 2).astype(np.float32)) + batch_result = _direct_sum(batch_a, batch_b) + assert batch_result.shape == (2, 5, 5) + +class TestNconModule: + """Test cases for ncon.py module.""" + + def test_ncon_operations(self): + """Test various Ncon tensor contraction operations.""" + # Test trace + ncon_trace = Ncon([[1, 1]]) + a = Tensor([[1.0, 2.0], [3.0, 4.0]]) + trace_result = ncon_trace([a]) + assert np.isclose(trace_result.asnumpy(), 5.0) + + # Test outer product + ncon_outer = Ncon([[-1], [-2]]) + b = Tensor([1.0, 2.0]) + c = Tensor([3.0, 4.0, 5.0]) + outer_result = ncon_outer([b, c]) + assert outer_result.shape == (2, 3) + + # Test batch matrix multiplication + ncon_matmul = Ncon([[-1, -2, 1], [-1, 1, -3]]) + d = Tensor(np.random.randn(2, 3, 4).astype(np.float32)) + e = Tensor(np.random.randn(2, 4, 5).astype(np.float32)) + matmul_result = ncon_matmul([d, e]) + assert matmul_result.shape == (2, 3, 5) + +class TestRadiusModule: + """Test cases for radius.py module.""" + + def test_radius_functions(self): + """Test radius computation functions.""" + np.random.seed(42) + x = np.random.random((8, 3)).astype(np.float32) + y = 
np.random.random((5, 3)).astype(np.float32) + + # Test basic radius + edge_index, batch_x, _ = radius(x, y, 0.5, max_num_neighbors=10) + assert edge_index.shape[0] == 2 and len(batch_x) == len(x) + + # Test radius_graph + edge_index, batch = radius_graph(x, 0.8, loop=False) + assert edge_index.shape[0] == 2 and len(batch) == len(x) + + # Test radius_full + x_batch = np.ones((2, 3, 3), dtype=np.float32) + edge_index_full, batch_x_full, _ = radius_full(x_batch, x_batch) + assert edge_index_full.shape[0] == 2 and len(batch_x_full) == 6 + +class TestInitializerModule: + """Test cases for initializer.py module.""" + + def test_initializers(self): + """Test custom initializers.""" + # Test Uniform initializer + from mindspore.common.initializer import initializer + uniform_init = Uniform(scale=2.0) + tensor = initializer(uniform_init, [3, 4], ms.float32) + values = tensor.asnumpy() + assert np.all(values >= 0.0) and np.all(values <= 2.0) + + # Test renormal_initializer + init1 = renormal_initializer('uniform') + assert isinstance(init1, Uniform) + + init2 = renormal_initializer('truncatedNormal') + assert isinstance(init2, TruncatedNormal) + + # Test invalid input + with pytest.raises(ValueError): + renormal_initializer('invalid_method') + +class TestPermModule: + """Test cases for perm.py module.""" + + def test_permutation_operations(self): + """Test permutation conversion and operations.""" + # Test conversion functions + n = 3 + for i in range(6): # 3! = 6 + perm = _from_int(i, n) + assert len(perm) == n and _to_int(perm) == i + + # Test permutation operations + perm1 = (0, 2, 1) + inv_perm1 = _inverse(perm1) + composed = _compose(perm1, inv_perm1) + assert composed == (0, 1, 2) # identity + + # Test group operations + group3 = _group(3) + assert len(group3) == 6 # 3! = 6 + + subset = {(0, 1, 2), (1, 0, 2)} + closure = _germinate(subset) + assert len(closure) >= len(subset) + +class TestInputValidation: + """Test input validation and error handling.""" + + def test_error_handling(self): + """Test various error conditions.""" + # Test radius with mismatched dimensions + x = np.random.random((5, 3)) + y = np.random.random((5, 4)) # Different last dimension + with pytest.raises(ValueError): + radius(x, y, 1.0) + + # Test radius_graph with invalid flow + with pytest.raises(ValueError): + radius_graph(x, 1.0, flow='invalid_flow') + + # Test _ndexpm with invalid input + invalid_mat = Tensor([1.0]) # 1D tensor + with pytest.raises(ValueError): + _ndexpm(invalid_mat) + +if __name__ == "__main__": + pytest.main([__file__])
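Sanity check (independent of MindSpore): the closed-form constants asserted in the tests above follow from standard e3nn conventions and can be reproduced with plain NumPy. A minimal sketch, assuming 'integral' normalization for the spherical harmonics and the right-handed rotation convention used in test_rotation.py:

import math
import numpy as np

# Y_0^0 under 'integral' normalization is the constant 1 / (2 * sqrt(pi)),
# the 0.28209479 asserted in test_core_functionality.
assert abs(1.0 / (2.0 * math.sqrt(math.pi)) - 0.28209479) < 1e-7

# Rotation about x by pi/2 reproduces expected_x from test_rotation_matrices.
def rot_x(angle):
    """Right-handed rotation matrix about the x axis."""
    c, s = math.cos(angle), math.sin(angle)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

assert np.allclose(rot_x(math.pi / 2), [[1, 0, 0], [0, 0, -1], [0, 1, 0]], atol=1e-12)

# An irrep of degree l spans 2*l + 1 components, so "2x0e+3x1o+1x2e" has
# dim 2*1 + 3*3 + 1*5 = 16, as asserted in test_irreps_properties_and_slicing.
assert sum(mul * (2 * l + 1) for mul, l in [(2, 0), (3, 1), (1, 2)]) == 16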