# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""normact"""
from mindspore import nn, Parameter, float32, ops
from mindspore.common.initializer import initializer
from ..o3.irreps import Irreps
from ..o3.tensor_product import TensorProduct
from ..o3.norm import Norm
class NormActivation(nn.Cell):
r"""Activation function for the norm of irreps.
Applies a scalar activation to the norm of each irrep and outputs a (normalized) version of that irrep multiplied
by the scalar output of the scalar activation.
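
    Concretely, for each irrep instance :math:`x` (with learnable bias :math:`b`, or :math:`b = 0` when ``bias``
    is False), the output is :math:`\mathrm{act}(\lVert x \rVert + b)\, x / \lVert x \rVert` if ``normalize`` is
    True (with norms clamped below by ``epsilon``), and :math:`\mathrm{act}(\lVert x \rVert + b)\, x` otherwise.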

    Args:
        irreps_in (Union[str, Irrep, Irreps]): the input irreps.
        act (Func): an activation function applied to the norm of each irrep of `irreps_in`.
        normalize (bool): whether to normalize the input features before multiplying them by the scalars from the
            nonlinearity. Default: True.
        epsilon (float): when ``normalize`` is True, norms smaller than ``epsilon`` will be clamped up to ``epsilon``
            to avoid division by zero. Not allowed when `normalize` is False. Default: None.
        bias (bool): whether to apply a learnable additive bias to the inputs of `act`. Default: False.
        init_method (Union[str, float, mindspore.common.initializer]): initialize parameters.
            Default: ``'zeros'``.
        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
        ncon_dtype (mindspore.dtype): The type of input tensors of the ncon computation module.
            Default: ``mindspore.float32``.

    Inputs:
        - **input** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`.

    Outputs:
        - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`.

    Raises:
        ValueError: If `epsilon` is not None and `normalize` is False.
        ValueError: If `epsilon` is not positive.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindchemistry.e3.nn import NormActivation
        >>> from mindspore import ops, Tensor
        >>> norm_activation = NormActivation("2x1e", ops.sigmoid, bias=True)
        >>> print(norm_activation)
        NormActivation [sigmoid] (2x1e -> 2x1e)
        >>> inputs = Tensor(ops.ones((4, 6)))
        >>> outputs = norm_activation(inputs)
        >>> print(outputs.shape)
        (4, 6)
    """

def __init__(self,
irreps_in,
act,
normalize=True,
epsilon=None,
bias=False,
init_method='zeros',
dtype=float32,
ncon_dtype=float32):
super().__init__()
self.irreps_in = Irreps(irreps_in)
self.irreps_out = Irreps(irreps_in)
        # Default to a small positive epsilon when normalizing; reject inconsistent settings.
        if epsilon is None:
            if normalize:
                epsilon = 1e-8
        elif not normalize:
            raise ValueError("`epsilon` and `normalize = False` don't make sense together.")
        elif not epsilon > 0:
            raise ValueError(f"epsilon {epsilon} is invalid, must be strictly positive.")
        self.epsilon = epsilon
        # The norm is computed squared when clamping, so store the squared threshold.
        if self.epsilon is not None:
            self._eps_squared = epsilon * epsilon
        else:
            self._eps_squared = 0.0
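        # Compute squared norms when an epsilon clamp is used (clamping happens before the sqrt); plain norms otherwise.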
self.norm = Norm(irreps_in, squared=(epsilon is not None), dtype=dtype)
self.act = act
self.normalize = normalize
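        # Optional learnable additive bias on the norms, one parameter per irrep instance.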
if bias:
self.bias = Parameter(initializer(init_method, (self.irreps_in.num_irreps,), dtype),
name=self.__class__.__name__)
else:
self.bias = None
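        # Elementwise tensor product that multiplies each irrep of the input by its matching scalar.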
self.scalar_multiplier = TensorProduct(irreps_in1=self.norm.irreps_out,
irreps_in2=irreps_in,
instructions='element',
dtype=dtype,
ncon_dtype=ncon_dtype)

    def construct(self, v):
        """Implement the norm-activation function for the input tensor."""
        # `self.norm` returns squared norms when an epsilon clamp is used, plain norms otherwise.
        norms = self.norm(v)
        if self._eps_squared > 0:
            # Clamp tiny squared norms before the square root to avoid division by zero below.
            norms = ops.maximum(norms, self._eps_squared)
            norms = ops.sqrt(norms)
nonlin_arg = norms
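        # Optionally shift the norms by the learnable per-irrep bias before the activation.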
if self.bias is not None:
nonlin_arg = nonlin_arg + self.bias
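        # Apply the scalar activation: one scaling factor per irrep instance.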
scalings = self.act(nonlin_arg)
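        # Fold the normalization into the scalings: act(||x|| + b) / ||x|| both normalizes and rescales x.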
if self.normalize:
scalings = scalings / norms
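        # The elementwise tensor product broadcasts each scalar over the components of its irrep.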
return self.scalar_multiplier(scalings, v)

    def __repr__(self):
        return f"{self.__class__.__name__} [{self.act.__name__}] ({self.irreps_in} -> {self.irreps_out})"