# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    maybe_sparsify,
    register_attention,
    sparsify,
)
from xformers.components.attention.attention_patterns import (
    causal_1d_pattern,
    global_token_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class GlobalAttentionConfig(AttentionConfig):
    attention_query_mask: torch.Tensor  # Mark the queries which have global attention
    causal: Optional[bool]
    force_sparsity: Optional[bool]
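
# A hedged usage sketch: thanks to the registration below, an instance can also
# be built from a plain config dict through xformers' generic build_attention
# helper. `query_mask` is a stand-in for an N x 1 boolean tensor:
#
#   from xformers.components.attention import build_attention
#
#   attn = build_attention(
#       {"name": "global", "dropout": 0.0, "attention_query_mask": query_mask}
#   )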


@register_attention("global", GlobalAttentionConfig)
class GlobalAttention(Attention):
    def __init__(
        self,
        dropout: float,
        attention_query_mask: torch.Tensor,
        causal: bool = False,
        force_sparsity: bool = False,
        *_,
        **__,
    ):
r"""
Global attention, as proposed for instance in BigBird_ or Longformer_.
Global means in that case that the queries positively labelled in the ```attention_query_mask``` can attend
to all the other queries. The queries negatively labelled in the ```attention_query_mask``` cannot attend to
any other query.
This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.
Args:
dropout (float): probability of an element to be zeroed
attention_query_mask (torch.Tensor): if true, this query can attend to all the others
"""
        super().__init__()

        assert attention_query_mask.dtype == torch.bool, "A boolean mask is expected"
        assert (
            attention_query_mask.shape[1] == 1
            and attention_query_mask.shape[0] > attention_query_mask.shape[1]
        ), "A N x 1 query mask is expected"

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.attention_mask = global_token_pattern(attention_query_mask[:, 0])
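        # For intuition, with attention_query_mask = [[True], [False], [False]]
        # the dense pattern ORs the flagged row and column, so token 0 attends
        # everywhere and tokens 1 and 2 attend only to token 0:
        #   [[1, 1, 1],
        #    [1, 0, 0],
        #    [1, 0, 0]]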
        self.force_sparsity = force_sparsity

        if causal:
            # attention_query_mask is N x 1, so the sequence length is shape[0];
            # shape[1] is always 1 per the assert above, which would make the
            # causal pattern a 1 x 1 no-op
            self.attention_mask &= causal_1d_pattern(attention_query_mask.shape[0])

        # Keep the pattern sparse when forced to, else let maybe_sparsify decide
        # whether the density makes a sparse representation worthwhile
        self.attention_mask = (
            sparsify(self.attention_mask)
            if self.force_sparsity
            else maybe_sparsify(self.attention_mask)
        )

        # Properties specific to this attention mechanism
        self.requires_same_k_q_dimensions = True
        self.supports_attention_mask = False
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        *_,
        **__,
    ):
        # Make sure that the mask is on the right device
        if self.attention_mask.device != q.device:
            self.attention_mask = self.attention_mask.to(q.device)

        # Mask-aware attention: merge the stored global pattern with the
        # user-provided mask, additively for AttentionMask instances and with a
        # boolean/bitwise AND otherwise
        if att_mask is not None:
            if att_mask.dtype == torch.bool and isinstance(
                self.attention_mask, AttentionMask
            ):
                if not isinstance(att_mask, AttentionMask):
                    att_mask = AttentionMask.from_bool(att_mask)
                mask = self.attention_mask + att_mask
            else:
                mask = self.attention_mask & att_mask
        else:
            mask = self.attention_mask

        # Handle q/k/v which would not fit the mask
        seq_len = q.shape[-2]
        q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))

        # Normal attention with the global tokens mask
        att = scaled_dot_product_attention(
            q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
        )

        # Crop away any padding that was added above
        return att[:, :seq_len, :]
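

# A minimal, hedged usage sketch (names and shapes are illustrative only): make
# the first token of a short sequence global and run self-attention over random
# data. At this density, maybe_sparsify should keep the pattern dense.
if __name__ == "__main__":
    seq_len, dim = 4, 16
    query_mask = torch.zeros(seq_len, 1, dtype=torch.bool)
    query_mask[0, 0] = True  # token 0 can attend to, and be attended by, all

    attention = GlobalAttention(dropout=0.0, attention_query_mask=query_mask)

    q = torch.randn(1, seq_len, dim)
    out = attention(q=q, k=q, v=q)
    print(out.shape)  # torch.Size([1, 4, 16])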