diff --git a/omni/layers/attention/backend/gdn_attn.py b/omni/layers/attention/backend/gdn_attn.py
index bb3d37ba19425d6ca64f33d13eb55c45b90181ad..b03d888d69f6efd34f034c749f863cdd8c89a127 100644
--- a/omni/layers/attention/backend/gdn_attn.py
+++ b/omni/layers/attention/backend/gdn_attn.py
@@ -26,20 +26,8 @@ class GDNAttentionBackend(AttentionBackend):
 
 @dataclass
 class GDNAttentionMetadata:
-    block_tables: torch.Tensor  # not used
-    query_lens: torch.Tensor  # not used
-    query_lens_list: List  # not used
-    seq_lens: torch.Tensor  # not used
-    seq_lens_list: List  # not used
-    max_query_len: Optional[int]  # not used
-    slot_mapping: torch.Tensor  # not used
-    slot_indices: torch.Tensor  # not used
-    is_only_prefill: bool  # not used
-    attn_state: AscendAttentionState  # not used
-    cos: Optional[torch.Tensor]  # not used
-    sin: Optional[torch.Tensor]  # not used
-    is_pd_seperate_d: bool  # not used
-    kv_index: Optional[torch.Tensor]  # not used
+    attn_state: AscendAttentionState
+    is_pd_seperate_d: bool
 
     num_prefills: int
     num_prefill_tokens: int
@@ -305,31 +293,13 @@ class GDNAttentionMetadataBuilder(AscendAttentionMetadataBuilder):
 
             non_spec_query_start_loc[num_decodes + 1:].fill_(non_spec_num_query_tokens)
 
-        ascend_attn_metadata = super().build(
-            num_reqs=num_reqs,
-            num_actual_tokens=num_actual_tokens,
-            max_query_len=max_query_len,
-            common_prefix_len=None,
-            graph_pad_size=graph_pad_size,
-            **kwargs,
-        )
+        attn_state = self.runner.attn_state
+        is_pd_seperate_d = self.runner.vllm_config.kv_transfer_config is not None and \
+            self.runner.vllm_config.kv_transfer_config.kv_role == 'kv_consumer'
 
         attn_metadata = GDNAttentionMetadata(
-            block_tables=ascend_attn_metadata.block_tables,
-            query_lens=ascend_attn_metadata.query_lens,
-            query_lens_list=ascend_attn_metadata.query_lens_list,
-            seq_lens=ascend_attn_metadata.seq_lens,
-            seq_lens_list=ascend_attn_metadata.seq_lens_list,
-            max_query_len=ascend_attn_metadata.max_query_len,
-            slot_mapping=ascend_attn_metadata.slot_mapping,
-            slot_indices=ascend_attn_metadata.slot_indices,
-            is_only_prefill=ascend_attn_metadata.is_only_prefill,
-            attn_state=ascend_attn_metadata.attn_state,
-            cos=ascend_attn_metadata.cos,
-            sin=ascend_attn_metadata.sin,
-            is_pd_seperate_d=ascend_attn_metadata.is_pd_seperate_d,
-            kv_index=ascend_attn_metadata.kv_index,
-            # up to here, just copy AscendMetadata and not used.
+            attn_state=attn_state,
+            is_pd_seperate_d=is_pd_seperate_d,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decodes=num_decodes,
@@ -373,58 +343,19 @@ class GDNAttentionMetadataBuilder(AscendAttentionMetadataBuilder):
                 dtype=self.runner.input_batch.block_table[0].get_device_tensor(max_pad_size).dtype
             )
 
-        query_lens = torch.ones(max_pad_size, dtype=torch.long, device=self.runner.device, pin_memory=True)
-        seq_lens = query_lens * 2
-
-        slot_indices = torch.stack([slot_mapping // self.block_size, slot_mapping % self.block_size], dim=1)
-
-        fake_positions = torch.zeros(max_pad_size, dtype=torch.int64, device=self.device)
-
-        cos, sin = None, None
-
         is_pd_seperate_d = self.runner.vllm_config.kv_transfer_config is not None and \
             self.runner.vllm_config.kv_transfer_config.kv_role == 'kv_consumer'
 
         non_spec_query_start_loc = torch.tensor([0, 1], device=self.runner.device)
         non_spec_state_indices_tensor = self.block_table.block_table[:, 0]
 
-        ascend_attn_metadata = AscendMetadata(
-            num_actual_tokens=num_tokens,
-            block_tables=block_table,
-            query_lens=query_lens,
-            query_lens_list=query_lens.tolist(),
-            seq_lens=seq_lens,
-            seq_lens_list=seq_lens.tolist(),
-            slot_mapping=slot_mapping,
-            slot_indices=slot_indices,
-            is_only_prefill=False,
-            attn_state=self.runner.attn_state,
-            cos=cos,
-            sin=sin,
-            is_pd_seperate_d=is_pd_seperate_d
-        )
-
-
         attn_metadata = GDNAttentionMetadata(
-            block_tables=ascend_attn_metadata.block_tables,
-            query_lens=ascend_attn_metadata.query_lens,
-            query_lens_list=ascend_attn_metadata.query_lens_list,
-            seq_lens=ascend_attn_metadata.seq_lens,
-            seq_lens_list=ascend_attn_metadata.seq_lens_list,
-            max_query_len=ascend_attn_metadata.max_query_len,
-            slot_mapping=ascend_attn_metadata.slot_mapping,
-            slot_indices=ascend_attn_metadata.slot_indices,
-            is_only_prefill=ascend_attn_metadata.is_only_prefill,
-            attn_state=ascend_attn_metadata.attn_state,
-            cos=ascend_attn_metadata.cos,
-            sin=ascend_attn_metadata.sin,
-            is_pd_seperate_d=ascend_attn_metadata.is_pd_seperate_d,
-            kv_index=ascend_attn_metadata.kv_index,
-            # up to here, just copy AscendMetadata
+            attn_state=self.runner.attn_state,
+            is_pd_seperate_d=is_pd_seperate_d,
             num_prefills=0,
             num_prefill_tokens=0,
             num_decodes=num_tokens,
-            is_decode=(num_decodes > 0),
+            is_decode=(num_tokens > 0),
             num_decode_tokens=num_tokens,
            num_spec_decodes=0,
             num_spec_decode_tokens=0,
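
Net effect, as a minimal standalone sketch (illustrative only): `GDNAttentionMetadata` keeps just the two Ascend-derived fields the GDN path actually reads, and the builder computes `is_pd_seperate_d` directly from `kv_transfer_config` instead of round-tripping through a throwaway `AscendMetadata`. The enum/config stubs and the helper name `build_dummy_decode_metadata` below are stand-ins for the real vLLM-Ascend types, not the actual API.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class AscendAttentionState(Enum):  # stub for the real Ascend enum
    PrefillNoCache = auto()
    DecodeOnly = auto()


@dataclass
class KVTransferConfig:  # stub: only the field the builder inspects
    kv_role: str = 'kv_consumer'


@dataclass
class GDNAttentionMetadata:
    # The former AscendMetadata copies (block_tables, slot_mapping, cos/sin,
    # kv_index, ...) were never read and are gone; only these two survive.
    attn_state: AscendAttentionState
    is_pd_seperate_d: bool
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    is_decode: bool
    num_decode_tokens: int
    # (spec-decode counters omitted for brevity)


def build_dummy_decode_metadata(
        attn_state: AscendAttentionState,
        kv_transfer_config: Optional[KVTransferConfig],
        num_tokens: int) -> GDNAttentionMetadata:
    """Hypothetical analogue of the dummy-build path in the third hunk."""
    # Decode-instance check computed once, straight from the config.
    is_pd_seperate_d = (kv_transfer_config is not None
                        and kv_transfer_config.kv_role == 'kv_consumer')
    return GDNAttentionMetadata(
        attn_state=attn_state,
        is_pd_seperate_d=is_pd_seperate_d,
        num_prefills=0,
        num_prefill_tokens=0,
        num_decodes=num_tokens,
        # Matches the fix above: gate on num_tokens, not num_decodes.
        is_decode=(num_tokens > 0),
        num_decode_tokens=num_tokens,
    )


if __name__ == '__main__':
    md = build_dummy_decode_metadata(
        AscendAttentionState.DecodeOnly, KVTransferConfig(), num_tokens=4)
    print(md.is_pd_seperate_d, md.is_decode)  # -> True True
```

Besides dropping the dead fields, this removes the `super().build()` call, so the GDN builder no longer constructs full Ascend attention metadata only to discard it.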