Last commit by luojianing on 2023-07-21 15:16 ("replace target=blank").

Function Differences with torch.prod


The following mapping relationships can be found in this file.

| PyTorch APIs | MindSpore APIs |
| --- | --- |
| torch.prod | mindspore.ops.prod |
| torch.Tensor.prod | mindspore.Tensor.prod |

torch.prod

torch.prod(input, dim, keepdim=False, *, dtype=None) -> Tensor

For more information, see torch.prod.

mindspore.ops.prod

mindspore.ops.prod(input, axis=(), keep_dims=False) -> Tensor

For more information, see [mindspore.ops.prod](https://mindspore.cn/docs/en/r2.0/api_python/ops/mindspore.ops.prod.html).

Differences

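PyTorch: Computes the product of the elements of input along the specified dim. keepdim controls whether the output keeps the reduced dimension (with size 1), so that it has the same number of dimensions as input. dtype sets the data type of the output Tensor; if it is specified, input is cast to dtype before the product is computed.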

MindSpore: Computes the product of the elements of input along the specified axis. keep_dims works the same way as keepdim in PyTorch. MindSpore does not have a dtype parameter. axis has a default value of (); when axis takes this default, the product of all elements of input is returned.

| Categories | Subcategories | PyTorch | MindSpore | Differences |
| --- | --- | --- | --- | --- |
| Parameters | Parameter 1 | input | input | Consistent |
| | Parameter 2 | dim | axis | In the PyTorch signature above, dim is required and must be a single integer (calling torch.prod(input) without dim returns the product of all elements instead). MindSpore axis can be an integer, a tuple of integers, or a list of integers |
| | Parameter 3 | keepdim | keep_dims | Same function, different parameter names |
| | Parameter 4 | dtype | - | PyTorch dtype sets the data type of the output Tensor. MindSpore does not have this parameter |
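When porting MindSpore code to PyTorch, the default axis=() maps directly to torch.prod(x) with no dim, but a tuple or list of axes has no single-call counterpart in the torch.prod(input, dim, ...) signature. The sketch below uses a hypothetical helper, prod_like_mindspore (not part of either library), that chains single-dim torch.prod calls to mimic MindSpore's axis and keep_dims behavior. It assumes a tensor with at least one dimension and has not been checked against every MindSpore edge case.

```python
import torch


def prod_like_mindspore(x, axis=(), keep_dims=False):
    """Hypothetical helper: emulate mindspore.ops.prod(input, axis, keep_dims) with torch.prod.

    axis may be an int, a tuple of ints, or a list of ints; an empty tuple or
    list means "reduce over all dimensions". Assumes x has at least one dimension.
    """
    axes = [axis] if isinstance(axis, int) else list(axis)
    if not axes:
        # Default axis=(): multiply every element, like MindSpore's default.
        axes = list(range(x.dim()))
    # Normalize negative axes, drop duplicates, and reduce the highest axis first
    # so the remaining axis indices stay valid when keep_dims=False.
    axes = sorted({a % x.dim() for a in axes}, reverse=True)
    out = x
    for a in axes:
        out = torch.prod(out, dim=a, keepdim=keep_dims)
    return out


x = torch.tensor([[1, 2.5, 3, 1], [2.5, 3, 2, 1]], dtype=torch.float32)
print(prod_like_mindspore(x).item())                                 # 112.5
print(prod_like_mindspore(x, axis=1, keep_dims=True).tolist())       # [[7.5], [15.0]]
print(prod_like_mindspore(x, axis=(0, 1), keep_dims=True).tolist())  # [[112.5]]
```

Reducing the highest axis first is the design choice that lets the helper accept axes in any order: removing a later dimension never shifts the index of an earlier one.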

Code Example

# PyTorch
import torch

input = torch.tensor([[1, 2.5, 3, 1], [2.5, 3, 2, 1]], dtype=torch.float32)
print(torch.prod(input, dim=1, keepdim=True))
# tensor([[ 7.5000],
#         [15.0000]])
print(torch.prod(input, dim=1, keepdim=True, dtype=torch.int32))
# tensor([[ 6],
#         [12]], dtype=torch.int32)

# MindSpore
import mindspore

x = mindspore.Tensor([[1, 2.5, 3, 1], [2.5, 3, 2, 1]], dtype=mindspore.float32)
print(mindspore.ops.prod(x, axis=1, keep_dims=True))
# [[ 7.5]
#  [15. ]]
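The int32 result in the PyTorch example comes from dtype casting the input before the product is taken, not from converting the float result (which would turn 7.5 into 7). The short check below confirms this on the PyTorch side. Because mindspore.ops.prod in r2.0 has no dtype parameter, casting the input first (for example with mindspore.ops.cast) is a plausible way to reproduce such results in MindSpore; that MindSpore step is an assumption and is not exercised here.

```python
import torch

x = torch.tensor([[1, 2.5, 3, 1], [2.5, 3, 2, 1]], dtype=torch.float32)

# With dtype=, torch.prod casts the input to int32 before multiplying,
# so 2.5 is truncated to 2: the rows become [1, 2, 3, 1] and [2, 3, 2, 1].
with_dtype = torch.prod(x, dim=1, keepdim=True, dtype=torch.int32)

# Cast-then-reduce by hand. torch.prod promotes integer inputs to int64 when no
# dtype is given, so cast the result back to int32 to match with_dtype exactly.
cast_first = torch.prod(x.to(torch.int32), dim=1, keepdim=True).to(torch.int32)

print(with_dtype.tolist())                  # [[6], [12]]
print(torch.equal(with_dtype, cast_first))  # True
```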