torch.dot(input, other, *, out=None)
For more information, see torch.dot.
mindspore.ops.tensor_dot(x1, x2, axes)
For more information, see mindspore.ops.tensor_dot.
The functionality of the MindSpore API differs from that of PyTorch.
PyTorch: Computes the dot product (inner product) of two 1D tensors with the same number of elements. Only 1D inputs are supported. The supported input data types are uint8, int8/16/32/64, and float32/64.
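As a plain-Python reference for the 1D contract described above (a sketch, not PyTorch's implementation), the dot product multiplies matching elements and sums the products. The name `dot_1d` is hypothetical and used only for illustration:

```python
def dot_1d(a, b):
    """Hypothetical reference: inner product of two equal-length 1D sequences."""
    if len(a) != len(b):
        # torch.dot likewise rejects 1D inputs of different lengths
        raise ValueError("inputs must have the same number of elements")
    # Multiply element-wise, then sum the products.
    return sum(x * y for x, y in zip(a, b))


print(dot_1d([2, 3, 4], [2, 1, 3]))  # 2*2 + 3*1 + 4*3 = 19
```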
MindSpore: Computes the dot product (tensor contraction) of two tensors over the specified axes. Tensors of any dimension are supported, but the specified axes must have the same size in both inputs. When both inputs are 1D and axes is set to 1 (so that axis 0 of each input is contracted), the result is the same as PyTorch's. The supported input data types are float16 and float32.
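The sketch below illustrates what contracting over chosen axes means, using NumPy's `np.tensordot` as a stand-in. It assumes that the `axes` argument of `mindspore.ops.tensor_dot` follows the same convention (an integer `n` contracts the last `n` axes of `x1` with the first `n` axes of `x2`; a pair of sequences names the axes explicitly); check the MindSpore reference before relying on that.

```python
import numpy as np

x1 = np.array([2, 3, 4], dtype=np.float32)
x2 = np.array([2, 1, 3], dtype=np.float32)

# 1D inputs with axes=1: axis 0 of both inputs is contracted -> scalar dot product.
dot = np.tensordot(x1, x2, axes=1)
# Explicit form of the same contraction: axis 0 of x1 with axis 0 of x2.
dot_explicit = np.tensordot(x1, x2, axes=([0], [0]))
# axes=0 contracts nothing, giving the outer product of shape (3, 3).
outer = np.tensordot(x1, x2, axes=0)

# Higher-dimensional inputs: the contracted axes must have equal sizes (3 here).
a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.arange(12, dtype=np.float32).reshape(3, 4)
c = np.tensordot(a, b, axes=([1], [0]))  # equivalent to a @ b, shape (2, 4)

print(dot, dot_explicit, outer.shape, c.shape)
```

Only the 1D case with a single contracted axis corresponds to `torch.dot`; the other calls show the extra generality that has no `torch.dot` counterpart.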
| Categories | Subcategories | PyTorch | MindSpore | Differences |
|---|---|---|---|---|
| Parameters | Parameter 1 | input | x1 | Different parameter names |
| | Parameter 2 | other | x2 | Different parameter names |
| | Parameter 3 | out | - | For details, see General Difference Parameter Table |
| | Parameter 4 | - | axes | Axes to contract. With 1D inputs, contracting axis 0 of both inputs (axes set to 1) gives the same result as PyTorch |
The data type of the input is int, and the data type of the output is also int.
# PyTorch
import torch

input_x1 = torch.tensor([2, 3, 4], dtype=torch.int32)
input_x2 = torch.tensor([2, 1, 3], dtype=torch.int32)
output = torch.dot(input_x1, input_x2)
print(output)
print(output.dtype)
# tensor(19, dtype=torch.int32)
# torch.int32

# MindSpore
# MindSpore doesn't currently support integer inputs for ops.tensor_dot.
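Because ops.tensor_dot accepts only float16 or float32 inputs according to the description above, one possible workaround for integer inputs is to cast to float32, compute, and cast the result back. The sketch below models that round trip in NumPy as a stand-in for MindSpore tensors; `int_dot_via_float32` and `FLOAT32_EXACT_INT_LIMIT` are hypothetical names. The result is exact only while the sum of the absolute element-wise products stays at or below 2**24, because float32 represents every integer in that range exactly; the helper refuses inputs outside it.

```python
import numpy as np

# float32 represents every integer with magnitude <= 2**24 exactly
FLOAT32_EXACT_INT_LIMIT = 2 ** 24


def int_dot_via_float32(x1, x2):
    """Hypothetical helper: integer dot product computed in float32, then cast back."""
    x1 = np.asarray(x1, dtype=np.int64)
    x2 = np.asarray(x2, dtype=np.int64)
    # If the sum of |products| fits, every product and partial sum is exact in float32.
    bound = int(np.sum(np.abs(x1 * x2)))
    if bound > FLOAT32_EXACT_INT_LIMIT:
        raise ValueError("result may not be exact in float32")
    # Cast to float32, contract axis 0 of both 1D inputs, then cast back to int32.
    result = np.tensordot(x1.astype(np.float32), x2.astype(np.float32), axes=1)
    return result.astype(np.int32)


print(int_dot_via_float32([2, 3, 4], [2, 1, 3]))
```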
The data type of the input is float, and the data type of the output is also float.
# PyTorch
import torch

input_x1 = torch.tensor([2, 3, 4], dtype=torch.float32)
input_x2 = torch.tensor([2, 1, 3], dtype=torch.float32)
output = torch.dot(input_x1, input_x2)
print(output)
print(output.dtype)
# tensor(19.)
# torch.float32
# MindSpore
import mindspore as ms
import mindspore.ops as ops
import numpy as np

input_x1 = ms.Tensor(np.array([2, 3, 4]), ms.float32)
input_x2 = ms.Tensor(np.array([2, 1, 3]), ms.float32)
# axes=1 contracts axis 0 of both 1D inputs, matching torch.dot
output = ops.tensor_dot(input_x1, input_x2, 1)
print(output)
print(output.dtype)
# 19.0
# Float32