335 Star 1.5K Fork 861

MindSpore / docs

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
tensor_dot.md 2.72 KB
一键复制 编辑 原始数据 按行查看 历史
TingWang 提交于 2023-09-18 10:08 . update link logo

比较与torch.dot的差异

查看源文件

torch.dot

torch.dot(input, other, *, out=None)

更多内容详见torch.dot

mindspore.ops.tensor_dot

mindspore.ops.tensor_dot(x1, x2, axes)

更多内容详见mindspore.ops.tensor_dot

使用方式

MindSpore此API功能与PyTorch不一致。

PyTorch:计算两个相同shape的tensor的点乘(内积),仅支持1D。支持的输入数据类型包括uint8、int8/16/32/64、float32/64。

MindSpore:计算两个tensor在任意轴上的点乘,支持任意维度的tensor,但指定的轴对应的形状要相等。当输入为1D, 轴设定为1时,和PyTorch的功能一致。支持的输入数据类型为float16或float32。

分类 子类 PyTorch MindSpore 差异
参数 参数 1 input x1 参数名不同
参数 2 other x2 参数名不同
参数 3 out - 详见通用差异参数表
参数 4 - axes 当输入为1D,axes设定为1时,和PyTorch的功能一致

代码示例 1

输入的数据类型是int,输出的数据类型也是int。

import torch
input_x1 = torch.tensor([2, 3, 4], dtype=torch.int32)
input_x2 = torch.tensor([2, 1, 3], dtype=torch.int32)
output = torch.dot(input_x1, input_x2)
print(output)
# tensor(19, dtype=torch.int32)
print(output.dtype)
# torch.int32

# MindSpore目前无法支持该功能。

代码示例 2

输入的数据类型是float,输出的数据类型也是float。

import torch
input_x1 = torch.tensor([2, 3, 4], dtype=torch.float32)
input_x2 = torch.tensor([2, 1, 3], dtype=torch.float32)
output = torch.dot(input_x1, input_x2)
print(output)
# tensor(19.)
print(output.dtype)
# torch.float32

import mindspore as ms
import mindspore.ops as ops
import numpy as np
input_x1 = ms.Tensor(np.array([2, 3, 4]), ms.float32)
input_x2 = ms.Tensor(np.array([2, 1, 3]), ms.float32)
output = ops.tensor_dot(input_x1, input_x2, 1)
print(output)
# 19.0
print(output.dtype)
# Float32
1
https://gitee.com/mindspore/docs.git
git@gitee.com:mindspore/docs.git
mindspore
docs
docs
master

搜索帮助