diff --git a/mindspore/python/mindspore/scipy/utils.py b/mindspore/python/mindspore/scipy/utils.py index f6b5117ba8b45e25ccab3983798e253451fb389e..3f90b3fcd3a63b877726d56a9b07339781e30ecf 100644 --- a/mindspore/python/mindspore/scipy/utils.py +++ b/mindspore/python/mindspore/scipy/utils.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """internal utility functions""" -from .. import nn, ops +from .. import ops from .. import numpy as mnp from ..numpy import where, zeros_like, dot, greater from ..ops import functional as F @@ -68,30 +68,21 @@ def _eps(x): return _eps_net(x[(0,) * x.ndim]) -class _SafeNormalize(nn.Cell): +def _safe_normalize(x, threshold=None): """Normalize method that cast very small results to zero.""" - - def __init__(self): - """Initialize LineSearch.""" - super(_SafeNormalize, self).__init__() - - def construct(self, x, threshold=None): - x_sum2 = F.reduce_sum(F.pows(x, 2.0)) - norm = F.pows(x_sum2, 1. / 2.0) - if threshold is None: - if x.dtype in mstype.float_type: - # pick the first element of x to get the eps - threshold = _eps(x) - else: - threshold = 0 - use_norm = greater(norm, threshold) - x_norm = x / norm - normalized_x = where(use_norm, x_norm, zeros_like(x)) - norm = where(use_norm, norm, zeros_like(norm)) - return normalized_x, norm - - -_safe_normalize = _SafeNormalize() + x_sum2 = F.reduce_sum(F.pows(x, 2.0)) + norm = F.pows(x_sum2, 1. / 2.0) + if threshold is None: + if x.dtype in (mstype.float32, mstype.float64): + # pick the first element of x to get the eps + threshold = _eps(x) + else: + threshold = 0 + use_norm = greater(norm, threshold) + x_norm = x / norm + normalized_x = where(use_norm, x_norm, zeros_like(x)) + norm = where(use_norm, norm, zeros_like(norm)) + return normalized_x, norm def _normalize_matvec(f):