luojianing committed on 2023-07-21 15:16: replace target=blank

Function Differences with tf.expand_dims

tf.expand_dims

tf.expand_dims(x, axis, name=None) -> Tensor

For more information, see tf.expand_dims.

mindspore.ops.expand_dims

mindspore.ops.expand_dims(input_x, axis) -> Tensor

For more information, see mindspore.ops.expand_dims.

Differences

TensorFlow: Adds a dimension of length 1 to the input x at the position given by axis.

MindSpore: Implements the same function as TensorFlow. Only the name of the first parameter differs (input_x instead of x), and there is no name parameter.

| Categories | Subcategories | TensorFlow | MindSpore | Differences |
| --- | --- | --- | --- | --- |
| Parameters | Parameter 1 | x | input_x | Same function, different parameter names |
| | Parameter 2 | axis | axis | - |
| | Parameter 3 | name | - | Not involved |
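
The shape rule that both APIs follow can be sketched in plain Python. This is a minimal illustration rather than either library's implementation: `expanded_shape` is a hypothetical helper introduced here, and it assumes the documented axis range of -(rank + 1) to rank, with negative values counted from the end of the output shape.

```python
def expanded_shape(shape, axis):
    """Return the shape after inserting a length-1 dimension at `axis`.

    Hypothetical helper for illustration only; it is not part of
    TensorFlow or MindSpore.
    """
    rank = len(shape)
    # Valid insert positions run from -(rank + 1) to rank, inclusive.
    if not -(rank + 1) <= axis <= rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    # A negative axis counts from the end of the output shape.
    if axis < 0:
        axis += rank + 1
    return tuple(shape[:axis]) + (1,) + tuple(shape[axis:])


print(expanded_shape((3, 4), 1))   # (3, 1, 4)
print(expanded_shape((1, 3), 2))   # (1, 3, 1)
print(expanded_shape((3, 4), -1))  # (3, 4, 1)
```

The first two calls use the same input shapes and axes as the code examples in this document, so the printed shapes can be compared against the outputs shown there.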

Code Example 1

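The two APIs behave identically and are called the same way. Here, inserting a dimension at axis 1 turns an input of shape (3, 4) into an output of shape (3, 1, 4).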

# TensorFlow
import numpy as np
import tensorflow as tf

x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)
axis = 1
out = tf.expand_dims(x, axis).numpy()
print(out)
# [[[ 1.  2.  3.  4.]]
#  [[ 5.  6.  7.  8.]]
#  [[ 9. 10. 11. 12.]]]

# MindSpore
import mindspore
import numpy as np
import mindspore.ops as ops
from mindspore import Tensor

input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32)
axis = 1
output = ops.expand_dims(input_params, axis)
print(output)
# [[[ 1.  2.  3.  4.]]
#  [[ 5.  6.  7.  8.]]
#  [[ 9. 10. 11. 12.]]]

Code Example 2

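As in Code Example 1, the two APIs are called the same way. Here, inserting a dimension at axis 2 turns an input of shape (1, 3) into an output of shape (1, 3, 1).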

# TensorFlow
import numpy as np
import tensorflow as tf

x = np.array([[1, 1, 1]], dtype=np.float32)
axis = 2
out = tf.expand_dims(x, axis).numpy()
print(out)
# [[[1.]
#   [1.]
#   [1.]]]
# MindSpore
import mindspore
import numpy as np
import mindspore.ops as ops
from mindspore import Tensor

input_params = Tensor(np.array([[1, 1, 1]]), mindspore.float32)
axis = 2
output = ops.expand_dims(input_params, axis)
print(output)
# [[[1.]
#   [1.]
#   [1.]]]