2.6K Star 8.6K Fork 4.8K

GVPMindSpore/mindspore

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
mindspore.nn.MultiheadAttention.rst 6.24 KB
一键复制 编辑 原始数据 按行查看 历史
宦晓玲 提交于 2024-02-21 15:46 . modify the dtype 2.3

mindspore.nn.MultiheadAttention

.. py:class:: mindspore.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, has_bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=mstype.float32)

    论文 `Attention Is All You Need <https://arxiv.org/pdf/1706.03762v5.pdf>`_ 中所述的多头注意力的实现。给定query向量,key向量和value向量,注意力计算流程如下:

    .. math::
        MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O

    其中, :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)` , :math:`W^O` 、 :math:`W_i^Q` 、 :math:`W_i^K` 、 :math:`W_i^V` 是权重矩阵。注意:默认输入、输出投影层中带有偏置参数。

    如果query、key和value相同,则上述即为自注意力机制的计算过程。

    参数:
        - **embed_dim** (int) - 模型的总维数。
        - **num_heads** (int) - 并行注意力头的数量。`num_heads` 需要能够被 `embed_dim` 整除(每个头的维数为 `embed_dim // num_heads`)。
        - **dropout** (float) - 应用到输入 `attn_output_weights` 上的随机丢弃比例. 默认值: ``0.0``。
        - **has_bias** (bool) - 是否给输入、输出投射层添加偏置。默认值: ``True``。
        - **add_bias_kv** (bool) - 是否给key、value序列的零维添加偏置。默认值: ``False``。
        - **add_zero_attn** (bool) - 是否给key、value序列的一维添加0。默认值: ``False``。
        - **kdim** (int) - key的总特征数。默认值: ``None`` (即 `kdim=embed_dim`)。
        - **vdim** (int) - value的总特征数。默认值:``None`` (即 `vdim=embed_dim`)。
        - **batch_first** (bool) - 如果为 ``True``,则输入输出Tensor的shape为 :math:`(batch, seq, feature)` ,否则shape为 :math:`(seq, batch, feature)` 。 默认值: ``False`` 。
        - **dtype** (:class:`mindspore.dtype`) - Parameter的数据类型。默认值: ``mstype.float32`` 。

    输入:
        - **query** (Tensor) - Query矩阵。当输入非Batch数据时,shape为: :math:`(L, E_q)` 。当输入Batch数据,参数 `batch_first=False` 时,shape为 :math:`(L, N, E_q)` ,
          当 `batch_first=True` 时,shape为 :math:`(N, L, E_q)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size,:math:`E_q` 为Query矩阵的维数 `embed_dim`。
          数据类型:float16、float32或者float64。注意力机制通过Query与Key-Value运算以生成最终输出。
        - **key** (Tensor) - Key矩阵。当输入非Batch数据时,shape为: :math:`(S, E_k)` 。当输入Batch数据,参数 `batch_first=False` 时,shape为 :math:`(S, N, E_k)` ,
          当 `batch_first=True` 时,shape为 :math:`(N, S, E_k)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size,:math:`E_k` 为Key矩阵的维数 `kdim`。数据类型:float16、float32或者float64。
        - **value** (Tensor) - Value矩阵。当输入非Batch数据时,shape为: :math:`(S, E_v)` 。当输入Batch数据,参数 `batch_first=False` 时,shape为 :math:`(S, N, E_v)` ,
          当 `batch_first=True` 时,shape为 :math:`(N, S, E_v)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size,:math:`E_v` 为Key矩阵的维数 `vdim`。数据类型:float16、float32或者float64。
        - **key_padding_mask** (Tensor, 可选) - 如果指定此值,则表示shape为 :math:`(N, S)`的掩码将被用于 `key`。当输入非Batch数据时,shape为: :math:`(S)` 。支持Bool和float类型。
          如果输入Tensor为Bool类型,则 `key` 中对应为 ``True`` 的位置将在Attention计算时被忽略。如果输入Tensor为float类型,则将直接与 `key` 相加。float支持数据类型:float16、float32或者float64。默认值:``None``。
        - **need_weights** (bool) - 是否需要返回 `attn_output_weights`,如果为 ``True``,则输出包含 `attn_output_weights`。默认值:``True``。
        - **attn_mask** (Tensor, 可选) - 如果指定此值,则表示shape为 :math:`(L, S)` 或 :math:`(N\cdot\text{num_heads}, L, S)` 的掩码将被用于Attention计算。其中 :math:`N` 为batch size,
          :math:`L` 为目标序列长度,:math:`S` 为源序列长度。如果输入为二维矩阵,则将自动沿batch维广播至三维矩阵。若为三维矩阵,则允许沿batch维使用不同的掩码。如果输入Tensor为Bool类型,则值为 ``True`` 对应位置允许被注意力计算。如果输入Tensor为float类型,则将直接与注意力权重相加。float支持数据类型:float16、float32或者float64。默认值:``None``。
        - **average_attn_weights** (bool) - 如果为 ``True``, 则返回值 `attn_weights` 为注意力头的平均值。如果为 ``False``,则 ``attn_weights`` 分别返回每个注意力头的值。
          本参数仅在 `need_weights=True` 时生效。默认值: ``True`` 。

    输出:
        Tuple,表示一个包含(`attn_output`, `attn_output_weights`)的元组。

        - **attn_output** - 注意力机制的输出。当输入非Batch数据时,shape为: :math:`(L, E)` 。当输入Batch数据, 参数 `batch_first=False` 时,shape为 :math:`(L, N, E)` ,
          当 `batch_first=True` 时,shape为 :math:`(N, L, E)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size, :math:`E` 为模型的总维数 `embed_dim`。
        - **attn_output_weights** - 仅当 ``need_weights=True`` 时返回。如果 `average_attn_weights=True`,则返回值 `attn_weights` 为注意力头的平均值。当输入非Batch数据时,
          shape为: :math:`(L, S)` ,当输入Batch数据时,shape为 :math:`(N, L, S)`。其中 :math:`N` 为batch size, :math:`L` 为目标序列的长度,:math:`S` 为源序列长度。
          如果 `average_attn_weights=False` ,分别返回每个注意力头的值。当输入非Batch数据时,shape为: :math:`(\text{num_heads}, L, S)` ,当输入Batch数据时,shape为
          :math:`(N, \text{num_heads}, L, S)`。

    异常:
        - **ValueError** - 如果 `embed_dim` 不能被 `num_heads` 整除。
        - **TypeError** - 如果 `key_padding_mask` 不是bool或float类型。
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
r2.3.1

搜索帮助