2.6K Star 8.6K Fork 4.8K

GVPMindSpore/mindspore

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
mindspore.ops.TopK.rst 2.10 KB
一键复制 编辑 原始数据 按行查看 历史

mindspore.ops.TopK

.. py:class:: mindspore.ops.TopK(sorted=True)

    沿最后一个维度查找 `k` 个最大元素和对应的索引。

    .. warning::
        - 如果 `sorted` 设置为 ``False`` ,它将使用aicpu运算符,性能可能会降低,另外,由于在不同平台上存在内存排布以及遍历方式不同等问题,`sorted` 设置为 ``False`` 时计算结果的显示顺序可能会出现不一致的情况。

    如果 `input_x` 是一维Tensor,则查找Tensor中 `k` 个最大元素,并将其值和索引输出为Tensor。`values[k]` 是 `input_x` 中 `k` 个最大元素,其索引是 `indices[k]` 。

    对于多维矩阵,计算每行中最大的 `k` 个元素(沿最后一个维度的相应向量),因此:

    .. math::
        values.shape = indices.shape = input.shape[:-1] + [k].

    如果两个比较的元素相同,则优先返回索引值较小的元素。

    参数:
        - **sorted** (bool,可选) - 如果为 ``True`` ,则获取的元素将按值降序排序。如果为 ``False`` ,则不对获取的元素进行排序。默认值: ``True`` 。

    输入:
        - **input_x** (Tensor) - 需计算的输入,目前GPU支持零维输入,但是Ascend或者CPU不支持。支持的数据类型:

          - Ascend:int8、uint8、int32、int64、float16、float32。
          - GPU:float16、float32。
          - CPU:所有数值型。

        - **k** (Union(Tensor, int)) - 指定计算最大元素的数量。若 `k` 为Tensor,其数据类型须为int32。若为Tensor,只支持零维Tensor或shape为 :math:`(1, )` 的一维Tensor。

    输出:
        由 `values` 和 `indices` 组成的tuple。

        - **values** (Tensor) - 最后一个维度的每个切片中的 `k` 最大元素。
        - **indices** (Tensor) - `k` 最大元素的对应索引。

    异常:
        - **TypeError** - 如果 `sorted` 不是bool。
        - **TypeError** - 如果 `input_x` 不是Tensor。
        - **TypeError** - 如果 `k` 不是int。
        - **TypeError** - 如果 `input_x` 的数据类型不被支持。
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
r2.2

搜索帮助