concat_doc.yaml 2.38 KB
chenfei_mindspore committed on 2024-02-26 16:24: move doc.yaml to doc dir
concat:
description: |
Concatenate the input tensors along the given axis.
The input data is a tuple or a list of tensors. These tensors have the same rank :math:`R`.
Set the given axis as :math:`m`, and :math:`0 \le m < R`. Set the number of input tensors as :math:`N`.
The :math:`i`-th tensor :math:`t_i` has the shape :math:`(x_1, x_2, ..., x_{mi}, ..., x_R)`,
where :math:`x_{mi}` is the :math:`m`-th dimension of :math:`t_i`. The shape of the output tensor is then

.. math::

    (x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)

Args:
tensors (Union[tuple, list]): A tuple or a list of input tensors.
Suppose the tuple or list holds two tensors, t1 and t2.
To concatenate them along axis 0, every dimension other than the :math:`0`-th
must match, that is,
:math:`t1.shape[1] = t2.shape[1], t1.shape[2] = t2.shape[2], ..., t1.shape[R-1] = t2.shape[R-1]`,
where :math:`R` is the rank of the tensors.
axis (int): The specified axis, whose value is in range :math:`[-R, R)`. Default: ``0`` .
Returns:
Tensor, the shape is :math:`(x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)`.
The data type is the same as that of the tensors in `tensors`.
Raises:
TypeError: If `axis` is not an int.
ValueError: If the tensors in `tensors` have different ranks.
ValueError: If `axis` is not in range :math:`[-R, R)`.
ValueError: If the shapes of the tensors in `tensors` differ in any dimension other than `axis`.
ValueError: If `tensors` is an empty tuple or list.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> input_x1 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
>>> input_x2 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
>>> output = ops.cat((input_x1, input_x2))
>>> print(output)
[[0. 1.]
 [2. 1.]
 [0. 1.]
 [2. 1.]]
>>> output = ops.cat((input_x1, input_x2), 1)
>>> print(output)
[[0. 1. 0. 1.]
 [2. 1. 2. 1.]]
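The shape rule and the error conditions described above can be sketched in plain Python, without MindSpore. The helper `concat_output_shape` is hypothetical and is not part of the MindSpore API. It only mirrors the documented rule (sum the sizes along `axis`, require all other dimensions to match) and the `Raises` list. The real operator's checks and error messages may differ.

```python
def concat_output_shape(shapes, axis=0):
    """Hypothetical helper: shape of the result of concatenating tensors
    with the given shapes along `axis`, following the documented rule."""
    if not isinstance(axis, int):
        raise TypeError("axis must be an int")
    if not shapes:
        raise ValueError("tensors must not be an empty tuple or list")

    rank = len(shapes[0])
    if any(len(shape) != rank for shape in shapes):
        raise ValueError("all tensors must have the same rank")
    if not -rank <= axis < rank:
        raise ValueError(f"axis must be in range [-{rank}, {rank})")

    m = axis % rank  # map a negative axis onto [0, rank)
    first = shapes[0]
    for shape in shapes[1:]:
        for dim in range(rank):
            # Every dimension except the concatenation axis must match.
            if dim != m and shape[dim] != first[dim]:
                raise ValueError("shapes differ outside the concatenation axis")

    out = list(first)
    out[m] = sum(shape[m] for shape in shapes)  # sizes add up along the axis
    return tuple(out)


# Shapes of the two (2, 2) tensors used in the Examples above.
print(concat_output_shape([(2, 2), (2, 2)]))     # (4, 2)
print(concat_output_shape([(2, 2), (2, 2)], 1))  # (2, 4)
```

The two printed shapes agree with the outputs in the Examples: concatenating along axis 0 stacks the rows, and concatenating along axis 1 extends the columns.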