2.3K Star 8K Fork 4.2K

GVPMindSpore / mindspore

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
functional.py 22.50 KB
一键复制 编辑 原始数据 按行查看 历史
bantao 提交于 2024-03-05 15:48 . neg pr
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The names of functional part are summarized here."""
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.ops import _constants
from mindspore.ops.function import *
from mindspore.ops.function.array_func import narrow, flatten
from mindspore.ops.function.math_func import all
from mindspore.ops import operations as P
from mindspore.ops.operations import array_ops
from mindspore.ops.operations._sequence_ops import TensorToTuple
from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _grad_ops, _csr_ops, _inner_ops, linalg_ops, _sequence_ops, other_ops
from mindspore.ops.operations.math_ops import Median
from mindspore.ops.operations.array_ops import UniqueConsecutive
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
from mindspore.ops.operations.math_ops import Roll
from mindspore.ops.composite.math_ops import mm
from mindspore.ops.function.math_func import dot
from mindspore.ops import auto_generate
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
from mindspore.ops.operations.manually_defined.ops_def import scalar_div, scalar_mod, scalar_add, scalar_mul,\
scalar_sub, scalar_gt, scalar_ge, scalar_le, scalar_lt, scalar_eq, scalar_floordiv, scalar_log, scalar_pow,\
scalar_uadd, scalar_usub
# Module-level primitive instances. Every name bound here that does not start
# with "_" is exported through `__all__`, which is built from dir() at the end
# of this module.
typeof = Primitive('typeof')
hastype = Primitive('hastype')
# `cast` is also registered with tensor_operator_registry further down, under
# 'cast' and the conversion aliases 'to', 'bool', 'float', 'half', 'int', 'long'.
cast = P.Cast()
dtype = P.DType()
isconstant = _inner_ops.IsConstant()
# Mark IsConstant as a const primitive.
# NOTE(review): presumably this makes it evaluate at compile time — confirm
# against Primitive.set_const_prim.
isconstant.set_const_prim(True)
merge = P.Merge()
geswitch = P.GeSwitch()
# `reduce_sum` is also registered with tensor_operator_registry under 'reduce_sum'.
reduce_sum = P.ReduceSum()
reduce_max = P.ReduceMax()
reduce_min = P.ReduceMin()
reduce_mean = P.ReduceMean()
tensor_range = P.Range()
# Also registered with tensor_operator_registry under 'tensor_scatter_update'.
tensor_scatter_update = P.TensorScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
mixed_precision_cast = _inner_ops.MixedPrecisionCast()
# The leading underscore keeps these two instances out of `__all__`.
_py_interpret = other_ops.PyInterpret()
_dtype_to_enum = DtypeToEnum()
# Dynamic-shape queries on sequences and tensors.
# The Python names do not mirror the primitive names (e.g.
# is_sequence_value_unknown -> "IsShapeUnKnown"). The primitive strings,
# including their "UnKnown" capitalization, are presumably the names registered
# by the backend, so do not "normalize" them without checking the C++ side.
is_sequence_value_unknown = Primitive("IsShapeUnKnown")
is_sequence_shape_unknown = Primitive("IsDimUnKnown")
is_dynamic_sequence_element_unknown = Primitive("IsElementUnknown")
is_tensor_bool_cond = Primitive("IsTensorBoolCond")
partial = P.Partial()
# depend: mount (attach) one node onto another node.
depend = P.Depend()
identity = P.identity()
# Primitives for tuple/list/dict/scalar/string/bool operations.
tuple_setitem = Primitive('tuple_setitem')
# The getitem name comes from the shared _constants table, not a string literal.
tuple_getitem = Primitive(_constants.kTupleGetItem)
list_getitem = Primitive('list_getitem')
list_setitem = Primitive('list_setitem')
dict_getitem = Primitive('dict_getitem')
dict_setitem = Primitive('dict_setitem')
tuple_div = Primitive("tuple_div")
# tuple_len and list_len are two names for the same "sequence_len" primitive.
tuple_len = Primitive("sequence_len")
list_len = Primitive("sequence_len")
tuple_reversed = Primitive("tuple_reversed")
make_range = Primitive("make_range")
make_tuple = Primitive('MakeTuple')
make_dict = Primitive('make_dict')
make_list = Primitive('make_list')
make_slice = Primitive('make_slice')
tuple_equal = Primitive("tuple_equal")
list_equal = Primitive("list_equal")
# Only scalar_ne is defined here. The other scalar comparisons (scalar_eq,
# scalar_lt, ...) are imported from manually_defined.ops_def at the top.
scalar_ne = Primitive('scalar_ne')
string_eq = Primitive('string_eq')
string_concat = Primitive('string_concat')
bool_not = Primitive('BoolNot')
bool_or = Primitive("bool_or")
bool_and = Primitive("bool_and")
bool_eq = Primitive("bool_eq")
array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_")
is_not = Primitive("is_not")
in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
array_reduce = Primitive('array_reduce')
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()
# Environ primitives. The name strings are presumably looked up against the
# backend's registered primitives, so they must be spelled exactly. All four
# share the "Environ" prefix; 'EnvironGet' previously read 'EnrironGet'
# (typo), which did not match its siblings.
environ_create = Primitive('EnvironCreate')
environ_set = Primitive('EnvironSet')
environ_get = Primitive('EnvironGet')
environ_add = Primitive('EnvironAdd')
# Miscellaneous primitives. SliceGetItem keeps its PascalCase Python name,
# unlike the snake_case names around it.
J = Primitive('J')
SliceGetItem = Primitive("SliceGetItem")
switch = Primitive('Switch')
switch_layer = Primitive('switch_layer')
# Used by the bprop (gradient) of sum.
reduced_shape = Primitive("reduced_shape")
# shape_mul: the input must be a shape tuple; the op multiplies the elements of
# that tuple together. It is also registered with tensor_operator_registry
# under 'shape_mul'.
shape_mul = _sequence_ops.shape_mul()
# General tensor methods. Each (key, op) pair is registered with
# tensor_operator_registry, in table order.
# Some keys alias the same op (e.g. 'nelement'/'numel').
# Some ops are primitive classes rather than functions (e.g. P.Cauchy, Median).
for _name, _op in (
    ('tuple_to_tensor', _sequence_ops.TupleToTensor),
    ('add', add),
    ('softmax', softmax),
    ('addr', addr),
    ('addcdiv', addcdiv),
    ('addcmul', addcmul),
    ('all', all),
    ('angle', angle),
    ('any', any),
    ('atan2', atan2),
    ('abs', abs),
    ('baddbmm', baddbmm),
    ('geqrf', geqrf),
    ('histc', histc),
    ('real', real),
    ('reciprocal', reciprocal),
    ('rsqrt', rsqrt),
    ('bincount', bincount),
    ('slogdet', slogdet),
    ('trace', trace),
    ('tril', tril),
    ('chunk', chunk),
    ('count_nonzero', count_nonzero),
    ('sqrt', sqrt),
    ('square', square),
    ('sub', sub),
    ('triu', triu),
    ('tan', tan),
    ('t', t),
    ('cauchy', P.Cauchy),
    ('log_normal', P.LogNormalReverse),
    ('acos', acos),
    ('cos', cos),
    ('acosh', acosh),
    ('cosh', cosh),
    ('cov', cov),
    ('asin', asin),
    ('sin', sin),
    ('sinc', sinc),
    ('pow', pow),
    ('negative', neg),
    ('amin', amin),
    ('amax', amax),
    ('aminmax', aminmax),
    ('mean', mean),
    ('prod', prod),
    ('round', round),
    ('reshape', reshape),
    ('reverse', reverse),
    ('reverse_sequence', reverse_sequence),
    ('xlogy', xlogy),
    ('flatten', flatten),
    ('transpose', transpose),
    ('broadcast_to', broadcast_to),
    ('matmul', matmul),
    ('inner', inner),
    ('xdivy', xdivy),
    ('argmax', argmax),
    ('argmin', argmin),
    ('cumsum', P.CumSum),
    ('cummin', cummin),
    ('cummax', cummax),
    ('nelement', numel),
    ('numel', numel),
    ('positive', positive),
    ('permute', permute),
    ('remainder', remainder),
    ('index_fill', index_fill),
    ('index_select', index_select),
    ('flip', flip),
    ('fliplr', fliplr),
    ('flipud', flipud),
    ('float_power', float_power),
    ('fmax', fmax),
    ('fmin', fmin),
    ('fmod', fmod),
    ('is_floating_point', is_floating_point),
    ('bitwise_and', bitwise_and),
    ('bitwise_or', bitwise_or),
    ('bitwise_xor', bitwise_xor),
    ('bitwise_left_shift', bitwise_left_shift),
    ('bitwise_right_shift', bitwise_right_shift),
    ('ger', ger),
    ('reduce_max', P.ReduceMax),
    ('reduce_min', P.ReduceMin),
    ('random_categorical', random_categorical),
    ('mirror_pad', P.MirrorPad),
    ('minimum', minimum),
    ('matrix_power', matrix_power),
    ('det', det),
    ('dot', dot),
    ('outer', outer),
    ('log1p', log1p),
    ('logdet', logdet),
    ('log_matrix_determinant', log_matrix_determinant),
    ('matrix_determinant', matrix_determinant),
    ('ceil', ceil),
    ('fillv2', P.FillV2),
    ('tile', tile),
    ('logit', logit),
    ('sum', sum),
    ('split', split),
    ('tensor_split', tensor_split),
    ('vsplit', vsplit),
    ('hsplit', hsplit),
    ('dsplit', dsplit),
    ('zeros_like', zeros_like),
    ('scalar_to_tensor', scalar_to_tensor),
    ('stop_gradient', stop_gradient),
    ('masked_fill', masked_fill),
    ('masked_select', masked_select),
    ('nonzero', nonzero),
    ('i0', i0),
    ('isclose', isclose),
    ('isneginf', isneginf),
    ('isposinf', isposinf),
    ('isreal', isreal),
    ('inv', inv),
    ('digamma', digamma),
    ('lgamma', lgamma),
    ('logaddexp', logaddexp),
    ('logaddexp2', logaddexp2),
    ('logcumsumexp', logcumsumexp),
    ('logsumexp', logsumexp),
    ('inverse', inverse),
    ('invert', invert),
    ('hardshrink', hardshrink),
    ('heaviside', heaviside),
    ('hypot', hypot),
    ('soft_shrink', soft_shrink),
    ('svd', linalg_ops.Svd),
    ('diag', diag),
    ('diagflat', diagflat),
    ('unique_consecutive', UniqueConsecutive),
    ('unique_with_pad', unique_with_pad),
    ('inplace_update', inplace_update),
    ('col2im', col2im),
    ('standard_laplace', P.StandardLaplace),
    ('erf', erf),
    ('erfc', erfc),
    ('standard_normal', P.StandardNormal),
    ('sigmoid', sigmoid),
    ('median', Median),
    ('tanh', tanh),
    ('exp', exp),
    ('addbmm', addbmm),
    ('addmm', addmm),
    ('addmv', addmv),
    ('adjoint', adjoint),
    ('asinh', asinh),
    ('arcsinh', arcsinh),
    ('atan', atan),
    ('atanh', atanh),
    ('arctanh', arctanh),
    ('bmm', bmm),
    ('conj', conj),
    ('cross', cross),
    ('erfinv', erfinv),
    ('less_equal', less_equal),
    ('lcm', lcm),
    ('ldexp', ldexp),
    ('clamp', clamp),
    ('fold', fold),
    ('unfold', unfold),
    ('diagonal', diagonal),
    ('diagonal_scatter', diagonal_scatter),
    ('index_add', index_add),
    ('greater', greater),
    ('greater_equal', greater_equal),
    ('igamma', igamma),
    ('igammac', igammac),
    ('lu_solve', lu_solve),
    ('nextafter', nextafter),
    ('qr', qr),
    ('ormqr', ormqr),
    ('masked_scatter', array_ops.MaskedScatter),
    ('index_put', array_ops.IndexPut),
    ('quantile', quantile),
    ('nanquantile', nanquantile),
    ('orgqr', orgqr),
):
    tensor_operator_registry.register(_name, _op)
# Drop the loop variables so the module's attribute set is unchanged.
del _name, _op
# Comparison / unary dunders, registered under their dunder names, followed by
# shape-related methods.
# Original note: MindSpore cannot support comparing Tensor(True) directly.
for _name, _op in (
    ('__eq__', equal),
    ('__ne__', not_equal),
    ('__neg__', neg),
    ('__lt__', tensor_lt),
    ('__le__', tensor_le),
    ('__gt__', tensor_gt),
    ('__ge__', tensor_ge),
    ('__logical_not__', logical_not),
    ('gt', gt),
    ('ge', ge),
    ('shape', shape),
    ('squeeze', squeeze),
    ('unsqueeze', unsqueeze),
    ('expand_dims', expand_dims),
    ('contiguous', auto_generate.contiguous),
):
    tensor_operator_registry.register(_name, _op)
# Drop the loop variables so the module's attribute set is unchanged.
del _name, _op
# Support the GE backend for non-comparison operators: cast/fill/gather/stack/log
# family. 'unstack' and 'unbind' share the same op; 'reduce_sum' maps to the
# module-level instance, while 'reducesum' maps to the P.ReduceSum class.
for _name, _op in (
    ('cast', cast),
    ('shape_mul', shape_mul),
    ('concatenate', concat),
    ('fill', fill),
    ('fills', fills),
    ('fill_diagonal', P.FillDiagonal),
    ('eye', eye),
    ('eigvals', eigvals),
    ('reduce_sum', reduce_sum),
    ('reducesum', P.ReduceSum),
    ('tensor_slice', tensor_slice),
    ('select', select),
    ('gather', gather),
    ('gather_d', gather_d),
    ('gather_elements', gather_elements),
    ('gather_nd', gather_nd),
    ('stack', stack),
    ('unstack', unstack),
    ('unbind', unstack),
    ('log', log),
    ('log10', log10),
    ('log2', log2),
    ('lerp', lerp),
    ('floor', floor),
    ('floor_divide', floor_divide),
):
    tensor_operator_registry.register(_name, _op)
# Drop the loop variables so the module's attribute set is unchanged.
del _name, _op
# Sparse (CSR/COO) tensor operators first, followed by the remaining tensor
# methods. Several keys are aliases:
# - 'to', 'bool', 'float', 'half', 'int' and 'long' all map to `cast`;
# - 'expand' maps to broadcast_to;
# - 'argmax_with_value' / 'argmin_with_value' map to max / min.
for _name, _op in (
    ('csr_add', csr_add),
    ('csr_mul', csr_mul),
    ('csr2coo', csr2coo),
    ('coo2csr', coo2csr),
    ('csr_div', csr_div),
    ('csr_mv', csr_mv),
    ('csr_mm_akg', _csr_ops.CSRMM),
    ('csr_mm', csr_mm),
    ('csr_reduce_sum', csr_reduce_sum),
    ('dense_to_sparse_csr', dense_to_sparse_csr),
    ('dense_to_sparse_coo', dense_to_sparse_coo),
    ('csr_to_dense', csr_to_dense),
    ('narrow', narrow),
    ('sort', sort),
    ('argsort', argsort),
    ('msort', msort),
    ('mm', mm),
    ('nan_to_num', nan_to_num),
    ('nansum', nansum),
    ('nanmean', nanmean),
    ('nanmedian', nanmedian),
    ('csr_to_coo', csr_to_coo),
    ('zeros', zeros),
    ('ones', ones),
    ('unsorted_segment_min', unsorted_segment_min),
    ('unsorted_segment_max', unsorted_segment_max),
    ('unsorted_segment_prod', unsorted_segment_prod),
    ('scatter', scatter),
    ('tensor_scatter_update', tensor_scatter_update),
    ('tensor_scatter_mul', tensor_scatter_mul),
    ('tensor_scatter_div', tensor_scatter_div),
    ('tensor_scatter_min', tensor_scatter_min),
    ('tensor_scatter_max', tensor_scatter_max),
    ('tensor_scatter_sub', tensor_scatter_sub),
    ('tensor_scatter_add', tensor_scatter_add),
    ('slice_scatter', slice_scatter),
    ('select_scatter', select_scatter),
    ('bernoulli', bernoulli),
    ('poisson', P.Poisson),
    ('randperm', P.Randperm),
    ('multinomial', multinomial),
    ('norm', norm),
    ('renorm', renorm),
    ('adaptive_max_pool2d', AdaptiveMaxPool2D),
    ('coalesce', coalesce),
    ('argmax_with_value', max),
    ('argmin_with_value', min),
    ('argwhere', argwhere),
    ('coo_add', coo_add),
    ('topk', topk),
    ('isfinite', isfinite),
    ('to', cast),
    ('bool', cast),
    ('float', cast),
    ('half', cast),
    ('int', cast),
    ('long', cast),
    ('cholesky', cholesky),
    ('cholesky_inverse', cholesky_inverse),
    ('cholesky_solve', cholesky_solve),
    ('expand', broadcast_to),
    ('tensortotuple', TensorToTuple),
    ('cumprod', cumprod),
    ('diff', diff),
    ('div', div),
    ('equal', equal),
    ('expm1', expm1),
    ('frac', frac),
    ('isinf', isinf),
    ('isnan', isnan),
    ('is_complex', is_complex),
    ('le', le),
    ('less', less),
    ('logical_and', logical_and),
    ('logical_not', logical_not),
    ('logical_or', logical_or),
    ('logical_xor', logical_xor),
    ('lstsq', lstsq),
    ('mvlgamma', mvlgamma),
    ('maximum', maximum),
    ('max', max),
    ('min', min),
    ('mul', mul),
    ('multiply', multiply),
    ('moveaxis', moveaxis),
    ('movedim', movedim),
    ('neg', neg),
    ('ne', ne),
    ('not_equal', not_equal),
    ('sgn', sgn),
    ('sign', sign),
    ('signbit', signbit),
    ('sinh', sinh),
    ('trunc', trunc),
    ('where', where),
    ('imag', imag),
    ('repeat_interleave', repeat_interleave),
    ('rad2deg', rad2deg),
    ('deg2rad', deg2rad),
    ('copysign', copysign),
    ('roll', Roll),
    ('rot90', rot90),
    ('swapaxes', swapaxes),
    ('swapdims', swapdims),
    ('repeat_elements', repeat_elements),
    ('top_k', top_k),
):
    tensor_operator_registry.register(_name, _op)
# Drop the loop variables so the module's attribute set is unchanged.
del _name, _op
# Export every public (non-underscore) module-level name. The Primitive class
# is imported only for internal use, so it is removed afterwards; remove()
# raises ValueError if it is somehow missing.
__all__ = [_attr for _attr in dir() if not _attr.startswith("_")]
__all__.remove('Primitive')
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
master

搜索帮助