代码拉取完成,页面将自动刷新
torch.topk(
input,
k,
dim=None,
largest=True,
sorted=True,
out=None
)
更多内容详见torch.topk。
class mindspore.ops.TopK(
sorted=False
)(input_x, k)
更多内容详见mindspore.ops.TopK。
PyTorch:支持获取指定维度的前k项最大或最小值。
MindSpore:目前仅支持获取最后维度的前k项最大值。
import mindspore as ms
import mindspore.ops as ops
import torch
# In MindSpore, obtain the first k largest entries of the last dimension.
topk = ops.TopK()
k = 3
input_x = ms.Tensor([[1, 2, 3, 4], [2, 4, 6, 8]], ms.float16)
values, indices = topk(input_x, k)
print(values)
print(indices)
# Out:
# [[4. 3. 2.]]
# [[8. 6. 4.]]
# [[3 2 1]]
# [[3 2 1]]
# In torch, obtain the first k largest or smallest entries of a specific dimension.
# largest=True
input_x = torch.tensor([[1, 2, 3, 4], [2, 4, 6, 8]], dtype=torch.float)
dim = 1
output = torch.topk(input_x, k, dim=dim, largest=True)
print(output)
# Out:
# torch.return_types.topk(
# values=tensor([[4., 3., 2.],
# [8., 6., 4.]]),
# indices=tensor([[3, 2, 1],
# [3, 2, 1]]))
# largest=False
output = torch.topk(input_x, k, dim=dim, largest=False)
print(output)
# Out:
# torch.return_types.topk(
# values=tensor([[1., 2., 3.],
# [2., 4., 6.]]),
# indices=tensor([[0, 1, 2],
# [0, 1, 2]]))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。