代码拉取完成,页面将自动刷新
.. py:class:: mindspore.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, has_bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=mstype.float32) 论文 `Attention Is All You Need <https://arxiv.org/pdf/1706.03762v5.pdf>`_ 中所述的多头注意力的实现。给定query向量,key向量和value向量,注意力计算流程如下: .. math:: MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O 其中, :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)` , :math:`W^O` 、 :math:`W_i^Q` 、 :math:`W_i^K` 、 :math:`W_i^V` 是权重矩阵。注意:默认输入、输出投影层中带有偏置参数。 如果query、key和value相同,则上述即为自注意力机制的计算过程。 参数: - **embed_dim** (int) - 模型的总维数。 - **num_heads** (int) - 并行注意力头的数量。`num_heads` 需要能够被 `embed_dim` 整除(每个头的维数为 `embed_dim // num_heads`)。 - **dropout** (float) - 应用到输入 `attn_output_weights` 上的随机丢弃比例. 默认值: ``0.0``。 - **has_bias** (bool) - 是否给输入、输出投射层添加偏置。默认值: ``True``。 - **add_bias_kv** (bool) - 是否给key、value序列的零维添加偏置。默认值: ``False``。 - **add_zero_attn** (bool) - 是否给key、value序列的一维添加0。默认值: ``False``。 - **kdim** (int) - key的总特征数。默认值: ``None`` (即 `kdim=embed_dim`)。 - **vdim** (int) - value的总特征数。默认值:``None`` (即 `vdim=embed_dim`)。 - **batch_first** (bool) - 如果为 ``True``,则输入输出Tensor的shape为 :math:`(batch, seq, feature)` ,否则shape为 :math:`(seq, batch, feature)` 。 默认值: ``False`` 。 - **dtype** (:class:`mindspore.dtype`) - Parameter的数据类型。默认值: ``mstype.float32`` 。 输入: - **query** (Tensor) - Query矩阵。当输入非Batch数据时,shape为: :math:`(L, E_q)` 。当输入Batch数据,参数 `batch_first=False` 时,shape为 :math:`(L, N, E_q)` , 当 `batch_first=True` 时,shape为 :math:`(N, L, E_q)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size,:math:`E_q` 为Query矩阵的维数 `embed_dim`。 数据类型:float16、float32或者float64。注意力机制通过Query与Key-Value运算以生成最终输出。 - **key** (Tensor) - Key矩阵。当输入非Batch数据时,shape为: :math:`(S, E_k)` 。当输入Batch数据,参数 `batch_first=False` 时,shape为 :math:`(S, N, E_k)` , 当 `batch_first=True` 时,shape为 :math:`(N, S, E_k)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size,:math:`E_k` 为Key矩阵的维数 `kdim`。数据类型:float16、float32或者float64。 - **value** (Tensor) - Value矩阵。当输入非Batch数据时,shape为: :math:`(S, E_v)` 。当输入Batch数据,参数 `batch_first=False` 时,shape为 :math:`(S, N, E_v)` , 当 `batch_first=True` 时,shape为 :math:`(N, S, E_v)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size,:math:`E_v` 为Key矩阵的维数 `vdim`。数据类型:float16、float32或者float64。 - **key_padding_mask** (Tensor, 可选) - 如果指定此值,则表示shape为 :math:`(N, S)`的掩码将被用于 `key`。当输入非Batch数据时,shape为: :math:`(S)` 。支持Bool和float类型。 如果输入Tensor为Bool类型,则 `key` 中对应为 ``True`` 的位置将在Attention计算时被忽略。如果输入Tensor为float类型,则将直接与 `key` 相加。float支持数据类型:float16、float32或者float64。默认值:``None``。 - **need_weights** (bool) - 是否需要返回 `attn_output_weights`,如果为 ``True``,则输出包含 `attn_output_weights`。默认值:``True``。 - **attn_mask** (Tensor, 可选) - 如果指定此值,则表示shape为 :math:`(L, S)` 或 :math:`(N\cdot\text{num_heads}, L, S)` 的掩码将被用于Attention计算。其中 :math:`N` 为batch size, :math:`L` 为目标序列长度,:math:`S` 为源序列长度。如果输入为二维矩阵,则将自动沿batch维广播至三维矩阵。若为三维矩阵,则允许沿batch维使用不同的掩码。如果输入Tensor为Bool类型,则值为 ``True`` 对应位置允许被注意力计算。如果输入Tensor为float类型,则将直接与注意力权重相加。float支持数据类型:float16、float32或者float64。默认值:``None``。 - **average_attn_weights** (bool) - 如果为 ``True``, 则返回值 `attn_weights` 为注意力头的平均值。如果为 ``False``,则 ``attn_weights`` 分别返回每个注意力头的值。 本参数仅在 `need_weights=True` 时生效。默认值: ``True`` 。 输出: Tuple,表示一个包含(`attn_output`, `attn_output_weights`)的元组。 - **attn_output** - 注意力机制的输出。当输入非Batch数据时,shape为: :math:`(L, E)` 。当输入Batch数据, 参数 `batch_first=False` 时,shape为 :math:`(L, N, E)` , 当 `batch_first=True` 时,shape为 :math:`(N, L, E)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size, :math:`E` 为模型的总维数 `embed_dim`。 - **attn_output_weights** - 仅当 ``need_weights=True`` 时返回。如果 `average_attn_weights=True`,则返回值 `attn_weights` 为注意力头的平均值。当输入非Batch数据时, shape为: :math:`(L, S)` ,当输入Batch数据时,shape为 :math:`(N, L, S)`。其中 :math:`N` 为batch size, :math:`L` 为目标序列的长度,:math:`S` 为源序列长度。 如果 `average_attn_weights=False` ,分别返回每个注意力头的值。当输入非Batch数据时,shape为: :math:`(\text{num_heads}, L, S)` ,当输入Batch数据时,shape为 :math:`(N, \text{num_heads}, L, S)`。 异常: - **ValueError** - 如果 `embed_dim` 不能被 `num_heads` 整除。 - **TypeError** - 如果 `key_padding_mask` 不是bool或float类型。
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。