torch.nn.MaxPool3d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)(input) -> Tensor
For more information, see torch.nn.MaxPool3d.
mindspore.nn.MaxPool3d(kernel_size=1, stride=1, pad_mode="valid", padding=0, dilation=1, return_indices=False, ceil_mode=False)(x) -> Tensor
For more information, see mindspore.nn.MaxPool3d.
PyTorch: Performs three-dimensional max pooling on the input multidimensional data.
MindSpore: This API is compatible with both TensorFlow and PyTorch. When pad_mode is "valid" or "same", its behavior is consistent with TensorFlow; when pad_mode is "pad", its behavior is consistent with PyTorch. In addition, MindSpore supports unbatched 4D input of shape (C, D, H, W), which is consistent with PyTorch 1.12.
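The three padding modes can be compared through their per-dimension output-length formulas. The sketch below uses a hypothetical helper, pool_out_len (not a MindSpore or PyTorch API). It assumes TensorFlow's VALID/SAME conventions for "valid"/"same", following the statement above that these modes match TensorFlow, and PyTorch's documented MaxPool3d formula for "pad".

```python
import math


def pool_out_len(length, kernel_size, stride, pad_mode="pad",
                 padding=0, dilation=1, ceil_mode=False):
    """Hypothetical helper: output length of one spatial dimension.

    Assumes TensorFlow conventions for "valid"/"same" and PyTorch's
    MaxPool3d formula for "pad".
    """
    eff_k = dilation * (kernel_size - 1) + 1  # span covered by one window
    if pad_mode == "valid":
        # Only windows that fit entirely inside the input are kept.
        return math.ceil((length - eff_k + 1) / stride)
    if pad_mode == "same":
        # Input is padded so that every start position yields an output.
        return math.ceil(length / stride)
    if pad_mode == "pad":
        # PyTorch: floor((L + 2p - d*(k-1) - 1) / s) + 1, or ceil if ceil_mode.
        num = length + 2 * padding - eff_k
        out = (math.ceil(num / stride) if ceil_mode else num // stride) + 1
        # PyTorch drops a last window that would start inside the right padding.
        if ceil_mode and (out - 1) * stride >= length + padding:
            out -= 1
        return out
    raise ValueError(f"unknown pad_mode: {pad_mode}")


spatial = (4, 4, 5)  # D, H, W of the input used in the example below
print(tuple(pool_out_len(L, 2, 1, "pad", padding=1) for L in spatial))  # (5, 5, 6)
print(tuple(pool_out_len(L, 2, 1, "valid") for L in spatial))           # (3, 3, 4)
print(tuple(pool_out_len(L, 2, 1, "same") for L in spatial))            # (4, 4, 5)
```

With pad_mode="pad", padding=1, kernel_size=2 and stride=1, the formula gives the (5, 5, 6) spatial shape that both frameworks print in the code example.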
| Categories | Subcategories | PyTorch | MindSpore | Difference |
|---|---|---|---|---|
| Parameters | Parameter 1 | kernel_size | kernel_size | Consistent function; PyTorch has no default value |
| | Parameter 2 | stride | stride | Consistent function, different default values (PyTorch defaults to kernel_size, MindSpore to 1) |
| | Parameter 3 | padding | padding | Consistent |
| | Parameter 4 | dilation | dilation | Consistent |
| | Parameter 5 | return_indices | return_indices | Consistent |
| | Parameter 6 | ceil_mode | ceil_mode | Consistent |
| | Parameter 7 | input | x | Consistent function, different parameter names |
| | Parameter 8 | - | pad_mode | Controls the padding mode; PyTorch does not have this parameter |

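The stride default is the difference most likely to change output shapes silently during migration. A minimal sketch, using hypothetical helpers (effective_stride, floor_out_len) and the (4, 4, 5) spatial size from the example below, shows that a default PyTorch layer and a default MindSpore layer with the same kernel_size produce different shapes. To reproduce torch.nn.MaxPool3d(kernel_size=k), pass stride=k explicitly to MindSpore.

```python
def effective_stride(kernel_size, stride=None):
    # torch.nn.MaxPool3d: stride=None means the stride equals kernel_size.
    return kernel_size if stride is None else stride


def floor_out_len(length, kernel_size, stride):
    # Hypothetical helper: no padding, dilation 1, floor rounding.
    # Covers PyTorch with padding=0 and TensorFlow-style "valid" padding.
    return (length - kernel_size) // stride + 1


spatial = (4, 4, 5)
k = 2
# PyTorch default: stride falls back to kernel_size (2).
torch_default = tuple(floor_out_len(L, k, effective_stride(k)) for L in spatial)
# MindSpore default: stride=1, pad_mode="valid".
ms_default = tuple(floor_out_len(L, k, 1) for L in spatial)
print(torch_default)  # (2, 2, 2)
print(ms_default)     # (3, 3, 4)
```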
Setting pad_mode="pad" in MindSpore makes its behavior consistent with PyTorch, as the following example shows.
```python
import mindspore as ms
from mindspore import Tensor
import mindspore.nn as nn
import torch
import numpy as np

np_x = np.random.randint(0, 10, [1, 2, 4, 4, 5])

# MindSpore: pad_mode='pad' with explicit padding matches PyTorch's behavior
x = Tensor(np_x, ms.float32)
max_pool = nn.MaxPool3d(kernel_size=2, stride=1, pad_mode='pad', padding=1, dilation=1, return_indices=False)
output = max_pool(x)
result = output.shape
print(result)
# (1, 2, 5, 5, 6)

# PyTorch: same kernel_size, stride, padding and dilation
x = torch.tensor(np_x, dtype=torch.float32)
max_pool = torch.nn.MaxPool3d(kernel_size=2, stride=1, padding=1, dilation=1, return_indices=False)
output = max_pool(x)
result = output.shape
print(result)
# torch.Size([1, 2, 5, 5, 6])
```
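To check values and not only shapes, a naive NumPy reference can help. The function below, max_pool3d_ref, is a hypothetical sketch of PyTorch-style 3D max pooling (floor rounding, implicit negative-infinity padding as described in the PyTorch documentation); it is not part of either framework and is written for clarity rather than speed.

```python
import numpy as np


def max_pool3d_ref(x, kernel_size, stride, padding, dilation=1):
    """Hypothetical reference: naive 3D max pooling on (N, C, D, H, W).

    Follows PyTorch's MaxPool3d semantics with ceil_mode=False:
    padded cells are -inf, so they never win the max.
    """
    k, s, p, d = kernel_size, stride, padding, dilation
    n, c = x.shape[:2]
    # Pad only the three spatial dimensions.
    xp = np.pad(x.astype(np.float64),
                ((0, 0), (0, 0), (p, p), (p, p), (p, p)),
                constant_values=-np.inf)
    out_sizes = [(L + 2 * p - d * (k - 1) - 1) // s + 1 for L in x.shape[2:]]
    out = np.empty((n, c, *out_sizes))
    span = d * (k - 1) + 1  # extent of one dilated window
    for i in range(out_sizes[0]):
        for j in range(out_sizes[1]):
            for w in range(out_sizes[2]):
                window = xp[:, :,
                            i * s:i * s + span:d,
                            j * s:j * s + span:d,
                            w * s:w * s + span:d]
                out[:, :, i, j, w] = window.max(axis=(2, 3, 4))
    return out


np_x = np.random.randint(0, 10, [1, 2, 4, 4, 5])
ref = max_pool3d_ref(np_x, kernel_size=2, stride=1, padding=1)
print(ref.shape)  # (1, 2, 5, 5, 6)
```

Running the same np_x through both layers above, the results could be compared against ref with np.allclose, after converting them with output.asnumpy() for MindSpore or output.numpy() for PyTorch.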