代码拉取完成,页面将自动刷新
.. py:class:: mindsponge.cell.GlobalAttention(num_head, gating, input_dim, output_dim, batch_size=None)
global gated自注意力机制,具体实现请参考 `Highly accurate protein structure prediction with AlphaFold <https://www.nature.com/articles/s41586-021-03819-2>`_ 。对于GlobalAttention模块,query/key/value tensor的shape需保持一致。
参数:
- **num_head** (int) - 头的数量。
- **gating** (bool) - 判断attention是否经过gating的指示器。
- **input_dim** (int) - 输入的最后一维的长度。
- **output_dim** (int) - 输出的最后一维的长度。
- **batch_size** (int) - attention中权重的batch size,仅在有while控制流时使用,默认值: ``None``。
输入:
- **q_data** (Tensor) - shape为 :math:`(batch\_size, seq\_length, input\_dim)` 的query tensor,其中seq_length是query向量的序列长度。
- **m_data** (Tensor) - shape为 :math:`(batch\_size, seq\_length, input\_dim)` 的key和value tensor。
- **q_mask** (Tensor) - shape为 :math:`(batch\_size, seq\_length, 1)` 的q_data的mask。
- **bias** (Tensor) - attention矩阵的偏置。默认值: ``None``。
- **index** (Tensor) - 在while循环中的索引,仅在有while控制流时使用。默认值: ``None``。
输出:
Tensor。GlobalAttention层的输出tensor,shape是 :math:`(batch\_size, seq\_length, output\_dim)`。
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。