代码拉取完成,页面将自动刷新
torch.flatten(
input,
start_dim=0,
end_dim=-1
)
For more information, see torch.flatten.
class mindspore.ops.Flatten(*args, **kwargs)(input_x)
For more information, see mindspore.ops.Flatten.
PyTorch: Supports the flatten of elements by specified dimensions.
MindSpore:Only the 0th dimension element is reserved and the elements of the remaining dimensions are flattened.
import mindspore as ms
import mindspore.ops as ops
import torch
import numpy as np
# In MindSpore, only the 0th dimension will be reserved and the rest will be flattened.
input_tensor = ms.Tensor(np.ones(shape=[1, 2, 3, 4]), ms.float32)
flatten = ops.Flatten()
output = flatten(input_tensor)
print(output.shape)
# Out:
# (1, 24)
# In torch, the dimension to reserve will be specified and the rest will be flattened.
input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
output1 = torch.flatten(input=input_tensor, start_dim=1)
print(output1.shape)
# Out:
# torch.Size([1, 24])
input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
output2 = torch.flatten(input=input_tensor, start_dim=2)
print(output2.shape)
# Out:
# torch.Size([1, 2, 12])
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。