由社区提供的PyTorch APIs和MindSpore APIs之间的映射,可能在参数、输入、输出、逻辑功能和特定场景等方面存在差异,可详见各API描述或已提供的差异对比。
也欢迎更多的MindSpore开发者参与完善映射内容。关于PyTorch与MindSpore关于框架机制差异,请参考:优化器对比,随机数策略对比,参数初始化对比。
API映射一致标准:API功能一致,参数个数或顺序一致,参数数据类型一致,参数默认值一致,参数名一致。同时满足所有一致条件被认为是API映射一致。
以下例外场景也被认为是API映射一致,
例外场景1:相较于API映射一致标准,仅API参数的输入数据类型支持范围不一样,包含以下3种子场景:
(1)MindSpore的API支持传入int,float,bool等类型的参数,但不支持传入int8或float64等小位宽数据类型的参数。
(2)MindSpore的API不支持传入复数类型的参数。
例外场景2:相较于MindSpore的API,PyTorch的API多出的参数是通用差异参数。通用差异参数存在的原因是PyTorch有部分参数是为性能优化等非功能性而增加的参数,MindSpore的性能优化机制与PyTorch不同。
例外场景3:如果能保证MindSpore的API在使用默认配置(或用户不配置)的情况下,能够实现与PyTorch对应API完全一致的功能,则MindSpore的API多于PyTorch的API的参数,功能不被认为是差异。
例外场景4:MindSpore将API中与PyTorch重载机制相关参数的默认值设置为None,PyTorch对应API的相应参数无默认值。
下面是例外场景4的举例, 在PyTorch 1.8.1中,torch.argmax具有两种API重载形式,分别是torch.argmax(input)和torch.argmax(input, dim, keepdim=False),其中torch.argmax(input)返回输入Tensor中的最大值元素的索引,torch.argmax(input, dim, keepdim=False)返回输入Tensor在指定轴上最大值的索引。
mindspore.ops.argmax只有一种API形式,即mindspore.ops.argmax(input, dim=None, keepdim=False),但mindspore.ops.argmax(input)与torch.argmax(input)功能相同,mindspore.ops.argmax(input, dim, keepdim)与torch.argmax(input, dim, keepdim)功能相同。相较于torch.argmax,mindspore.ops.argmax参数dim的默认值被设置为None,仅是为了适配torch.argmax的两种API重载形式,因此例外场景4也被认为是API映射一致。
因为框架机制原因,MindSpore不提供PyTorch的以下参数:
参数名 | 功能 | 说明 |
---|---|---|
out | 表示输出的Tensor | 把运算结果赋值给out参数,MindSpore目前无此机制 |
layout | 表示内存分布策略 | PyTorch支持torch.strided和torch.sparse_coo两种模式, MindSpore目前无此机制 |
device | 表示Tensor存放位置 | 包含设备类型及可选设备号,MindSpore目前支持算子或网络级别的设备调度 |
requires_grad | 表示是否更新梯度 | MindSpore中可以通过Parameter.requires_grad 控制 |
generator | 表示伪随机数生成器 | MindSpore中通过随机数API的seed参数进行控制 |
pin_memory | 表示是否使用锁页内存 | MindSpore目前无此机制 |
memory_format | 表示Tensor的内存格式 | MindSpore目前无此机制 |
stable | 表示是否稳定排序 | 一般用在排序算法的API中,MindSpore目前无此功能 |
inplace | 表示在不更改变量内存地址的情况下,直接修改变量的值 | MindSpore目前提供少量inplace的API,例如assign_add 等 |
sparse_grad | 表示是否对梯度稀疏化 | MindSpore目前无此机制 |
size_average | PyTorch废弃参数 | MindSpore中可以使用reduction 参数替代 |
reduce | PyTorch废弃参数 | MindSpore中可以使用reduction 参数替代 |
PyTorch 1.12 APIs | MindSpore APIs | 说明 |
---|---|---|
torch.hsplit | mindspore.ops.hsplit | 一致 |
torch.permute | mindspore.ops.permute | 一致 |
torch.vsplit | mindspore.ops.vsplit | 一致 |
PyTorch 1.8.1 APIs | MindSpore APIs | 说明 |
---|---|---|
torch.distributions.laplace.Laplace | mindspore.ops.standard_laplace | 差异对比 |
PyTorch 1.8.1 APIs | MindSpore APIs | 说明 |
---|---|---|
torch.nn.Module.apply | mindspore.nn.Cell.apply | 一致 |
PyTorch 1.8.1 APIs | MindSpore APIs | 说明 |
---|---|---|
torch.nn.utils.clip_grad_value_ | mindspore.ops.clip_by_value | 差异对比 |
torch.nn.utils.clip_grad_norm_ | mindspore.ops.clip_by_norm | 差异对比 |
PyTorch 1.12 APIs | MindSpore APIs | 说明 |
---|---|---|
torch.Tensor.hsplit | mindspore.Tensor.hsplit | 功能一致,参数名不同 |
torch.Tensor.vsplit | mindspore.Tensor.vsplit | 功能一致,参数名不同 |
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。