From 03381827b7cabba9d195d850791b4bdee3d6c03e Mon Sep 17 00:00:00 2001 From: Nickyi Date: Mon, 10 Nov 2025 17:29:27 +0800 Subject: [PATCH 1/6] adapt pangu_pro_moe_v2 --- .vscode/settings.json | 3 + omni/adaptors/vllm/worker/npu_model_runner.py | 5 + omni/layers/attention/backend/attention.py | 367 ++++- omni/models/__init__.py | 4 + .../pangu/pangu_pro_moe_v2/fused_moe.py | 762 ++++++++++ .../pangu/pangu_pro_moe_v2/pangu_moe_v2.py | 1341 +++++++++++++++++ tools/scripts/send_http.py | 41 + .../test_start_api_servers_pangu_72Bv2.sh | 29 + 8 files changed, 2550 insertions(+), 2 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 omni/models/pangu/pangu_pro_moe_v2/fused_moe.py create mode 100644 omni/models/pangu/pangu_pro_moe_v2/pangu_moe_v2.py create mode 100644 tools/scripts/send_http.py create mode 100644 tools/scripts/test_start_api_servers_pangu_72Bv2.sh diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..24e2c93c6 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "workbench.editor.wrapTabs": true +} \ No newline at end of file diff --git a/omni/adaptors/vllm/worker/npu_model_runner.py b/omni/adaptors/vllm/worker/npu_model_runner.py index 479aa7d39..b6eb80c31 100644 --- a/omni/adaptors/vllm/worker/npu_model_runner.py +++ b/omni/adaptors/vllm/worker/npu_model_runner.py @@ -1270,6 +1270,11 @@ class NPUModelRunner(GPUModelRunner): self.device, self.model_config, self.enable_torchair_graph_mode) + hf_config = self.vllm_config.model_config.hf_config + v_channels = getattr(hf_config, "v_channels", None) + if v_channels is not None: + kv_caches[layer_name] = (kv_caches[layer_name][0], kv_caches[layer_name][1][...,:v_channels].contiguous()) + if preemption_mode and preemption_mode == "swap": cpu_num_blocks = int(self.vllm_config.cache_config.swap_space_bytes // kv_cache_spec.page_size_bytes // len(kv_cache_config.tensors)) diff --git a/omni/layers/attention/backend/attention.py 
b/omni/layers/attention/backend/attention.py index 8dc30403b..c87794155 100644 --- a/omni/layers/attention/backend/attention.py +++ b/omni/layers/attention/backend/attention.py @@ -26,7 +26,7 @@ import torchair as tng from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type - +import torch.nn.functional as F from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState @@ -58,6 +58,10 @@ def unified_ascend_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, + sink_query: Optional[torch.Tensor] = None, + sink_key: Optional[torch.Tensor] = None, + sink_value: Optional[torch.Tensor] = None, + v_head_size: Optional[int] = None, ) -> None: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -73,7 +77,11 @@ def unified_ascend_attention_with_output( kv_cache, attn_metadata, output, - trace_flag=False) + trace_flag=False, + sink_query=sink_query, + sink_key=sink_key, + sink_value=sink_value, + v_head_size=v_head_size) return @@ -501,10 +509,13 @@ class AscendAttentionBackendImpl(AttentionImpl): ) self.use_tnd_pa = model_extra_config.operator_opt_config.use_tnd_pa self.kv_stream = kv_stream + self.use_sink = getattr(cur_vllm_config.model_config.hf_config, "v_channels", None) is not None def forward(self, *args, **kwargs): if self.use_tnd_pa: return self.forward_pa(*args, **kwargs) + elif self.use_sink: + return self.forward_sink(*args, **kwargs) else: return self.forward_vanilla(*args, **kwargs) @@ -807,6 +818,358 @@ class AscendAttentionBackendImpl(AttentionImpl): return output.view(num_tokens, self.hidden_size) + def forward_sink( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple, + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, + 
+ trace_flag: bool = True, + sink_query: Optional[torch.Tensor] = None, + sink_key: Optional[torch.Tensor] = None, + sink_value: Optional[torch.Tensor] = None, + v_head_size: Optional[int] = None, + ) -> torch.Tensor: + """Forward pass with Ascend attention. + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache: shape = [2, num_blocks, block_size, + num_kv_heads * head_size] + key_cache = [num_blocks, block_size, + num_kv_heads * head_size] + value_cache = [num_blocks, block_size, + num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size * seq_len, num_heads, head_size] + """ + num_tokens = query.shape[0] + if v_head_size is None: + v_head_size = self.head_size + if output is None: + output = torch.empty(num_tokens, + self.num_heads, + self.head_size, + dtype=query.dtype, + device=query.device) + # print("num_tokens", num_tokens) #2048 + # print("self.num_heads", self.num_heads) #16 + # print("self.head_size", self.head_size) #192 + # print("self.hidden_size", self.hidden_size) #3072 + if attn_metadata is None: + return output.view(num_tokens, -1) + + if not (layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0): + raise RuntimeError("layer._k_scale_float and layer._v_scale_float must both be 1.0") + attn_type = self.attn_type + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "AscendAttentionBackendImpl") + # View q k v to BSH.
+ special_head_size_flag = (self.head_size == 192) + sink_key_flag = (sink_key is not None) + + # print("query.shape", query.shape) #[256, 16, 192] + # print("key.shape", key.shape) #[256, 1, 192] + # print("value.shape", value.shape) + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + if sink_key_flag: + value = value.view(-1, self.num_kv_heads, v_head_size) # v has a different h + else: + value = value.view(-1, self.num_kv_heads, self.head_size) + value = value.contiguous() + + # update kv cache + if kv_cache[0].numel() > 0 or kv_cache[1].numel(): + + self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] + # print("self.value_cache.shape", self.value_cache.shape) + + block_size = self.key_cache.shape[1] + # kv_cache: shape = [2, num_blocks, block_size, + # num_kv_heads * head_size] + cast_key = key.reshape(-1, 1, self.num_kv_heads * self.head_size) + if sink_key_flag: + cast_value = value.reshape(-1, 1, self.num_kv_heads * v_head_size) + else: + cast_value = value.reshape(-1, 1, self.num_kv_heads * self.head_size) + + if attn_metadata.attn_state != AscendAttentionState.DecodeOnly: + # if prefill does not use paged attention, + # (1) saving keys and values into kv_cache, and + # (2) GQA + # can run simultaneously in two streams + if self.kv_stream is not None and not hasattr(layer, 'quant_method') and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + stream_for_reshape_and_cache = self.kv_stream + self.kv_stream.wait_stream(torch.npu.current_stream()) + else: + stream_for_reshape_and_cache = torch.npu.current_stream() + with torch.npu.stream(stream_for_reshape_and_cache): + # print("key.shape", key.shape) #[8, 1, 192] + # print("value.shape", value.shape) #[8, 1, 128] + # print("self.key_cache.view(self.key_cache.shape[0], block_size, self.num_kv_heads, self.head_size).shape", self.key_cache.view(self.key_cache.shape[0], block_size, self.num_kv_heads, self.head_size).shape) 
#[140, 128, 1, 192] + # print("self.value_cache.view(self.value_cache.shape[0], block_size, self.num_kv_heads, self.v_head_size).shape", self.value_cache.view(self.value_cache.shape[0], block_size, self.num_kv_heads, v_head_size).shape) #[140, 128, 1, 128] + # print("attn_metadata.slot_mapping.int().shape", attn_metadata.slot_mapping.int().shape) #[8] + torch_npu._npu_reshape_and_cache( + key, + value, + self.key_cache.view(self.key_cache.shape[0], block_size, self.num_kv_heads, self.head_size), + self.value_cache.view(self.value_cache.shape[0], block_size, self.num_kv_heads, v_head_size), + attn_metadata.slot_mapping.int() + ) + else: + # print("self.key_cache", self.key_cache.shape) #[140, 128, 192] + # print("attn_metadata.slot_indices", attn_metadata.slot_indices.shape) #[256, 2] + # print("cast_key", cast_key.shape) #[256, 1, 192] + # print("self.value_cache", self.value_cache.shape) + # print("attn_metadata.slot_indices", attn_metadata.slot_indices) + # print("cast_value", cast_value.shape) ##[256, 1, 128] + + if sink_key_flag: + torch_npu._npu_reshape_and_cache( + cast_key, + cast_value, + self.key_cache.view(self.key_cache.shape[0], block_size, self.num_kv_heads, self.head_size), + self.value_cache.view(self.value_cache.shape[0], block_size, self.num_kv_heads, v_head_size), + attn_metadata.slot_mapping.int() + ) + else: + torch_npu.scatter_update_(self.key_cache, attn_metadata.slot_indices, cast_key, -2) + torch_npu.scatter_update_(self.value_cache, attn_metadata.slot_indices, cast_value, -2) + + if sink_key_flag and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + # kv cache start from block 1 and slots 128, so we store sink in block 0. 
+ slots = torch.arange(0, 128, device=sink_key.device, dtype=torch.int32) + bsz = attn_metadata.query_lens.shape[0] + # print("sink_key.shape", sink_key.shape) + # print("sink_value.shape", sink_value.shape) + + # print("self.key_cache.shape", self.key_cache.shape) + # print("self.value_cache.shape", self.value_cache.shape) + # print("slots.shape", slots.shape) + torch_npu._npu_reshape_and_cache( + key=sink_key, + value=sink_value, + key_cache=self.key_cache.view(self.key_cache.shape[0], block_size, self.num_kv_heads, self.head_size), + value_cache=self.value_cache.view(self.value_cache.shape[0], block_size, self.num_kv_heads, v_head_size), + slot_indices=slots) + + + if hasattr(layer, 'quant_method'): + pass + # V0-Style scheduler situation. + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + if not (os.getenv("ENABLE_PREFILL_TND", "0") == "1"): + if attn_metadata is None: + raise RuntimeError("attn_metadata must not be None") + + if len(attn_metadata.query_lens_list) == 1: + attn_output = torch_npu.npu_fused_infer_attention_score( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BSND", + scale=self.scale, + sparse_mode=3, + actual_seq_lengths=attn_metadata.query_lens_list, + actual_seq_lengths_kv=attn_metadata.seq_lens_list, + atten_mask=AscendAttentionBackendImpl.SHARE_MASK_TRIL_SPARSE, + )[0].view(-1, self.num_heads, self.head_size) + # print("attn_output.shape", attn_output.shape) + # print("output.shape", output.shape) + + output = output.view_as(attn_output) + output.copy_(attn_output) + else: + actual_seq_qlen = np.array(attn_metadata.query_lens).cumsum().tolist() + actual_seq_kvlen = np.array(attn_metadata.seq_lens).cumsum().tolist() + + attn_output = torch_npu.npu_fusion_attention( + query[:actual_seq_qlen[-1], :], + key[:actual_seq_qlen[-1], :], + value[:actual_seq_qlen[-1], :], + head_num=self.num_heads, + input_layout="TND", + 
scale=self.scale, + atten_mask=AscendAttentionBackendImpl.SHARE_MASK_TRIL_SPARSE, + sparse_mode=3, + actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen)[0] + + output[:actual_seq_qlen[-1], :].copy_(attn_output) + else: + if attn_metadata is None: + raise RuntimeError("attn_metadata must not be None") + + + if sink_key_flag: + bsz = attn_metadata.query_lens.shape[0] + cu_seqlen_q = [0] + attn_metadata.query_lens.tolist() + cu_seqlen_q = torch.tensor(cu_seqlen_q, device=key.device) + cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0) + key_list = [] + value_list = [] + for i in range(bsz): + k = key[cu_seqlen_q[i]:cu_seqlen_q[i+1]] + v = value[cu_seqlen_q[i]:cu_seqlen_q[i+1]] + key_list.append(torch.cat([sink_key, k], dim=0)) + value_list.append(torch.cat([sink_value, v], dim=0)) + key = torch.cat(key_list, dim=0) + value = torch.cat(value_list, dim=0) + + + if special_head_size_flag: + atten_mask = ~torch.tril( + torch.ones((2048, 2048), device='npu', dtype=torch.bool) + ) + if sink_key_flag: + seq_lens_with_sink = attn_metadata.seq_lens + sink_key.shape[0] + cu_seqlen = [0] + (attn_metadata.seq_lens.tolist() if sink_key is None else seq_lens_with_sink.tolist()) + cu_seqlen = torch.tensor(cu_seqlen, device="npu") + cu_seqlen = torch.cumsum(cu_seqlen, dim=0)[1:] + attn_output = torch_npu.npu_fused_infer_attention_score( + query, + key, + value, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="TND", + scale=self.scale, + sparse_mode=3, + pre_tokens=2147483647, + next_tokens=0, + atten_mask=atten_mask, + inner_precise=0, + actual_seq_lengths=cu_seqlen_q[1:], + actual_seq_lengths_kv=cu_seqlen)[0] + output.copy_(attn_output) + output = output.view(-1, self.num_heads * v_head_size) + + + else: + actual_seq_qlen = np.array(attn_metadata.query_lens).cumsum().tolist() + actual_seq_kvlen = np.array(attn_metadata.seq_lens).cumsum().tolist() + attn_output = torch_npu.npu_fused_infer_attention_score( + 
+ query[:actual_seq_qlen[-1],:,:], + key[:actual_seq_qlen[-1],:,:], + value[:actual_seq_qlen[-1],:,:], + num_heads = self.num_heads, + num_key_value_heads = self.num_kv_heads, + input_layout = "TND", + scale = self.scale, + sparse_mode = 3, + actual_seq_lengths = actual_seq_qlen, + actual_seq_lengths_kv = actual_seq_kvlen, + atten_mask = AscendAttentionBackendImpl.SHARE_MASK_TRIL_SPARSE, + )[0].view(-1, self.num_heads, self.head_size) + + output[:actual_seq_qlen[-1], :].copy_(attn_output) + if stream_for_reshape_and_cache != torch.npu.current_stream(): + torch.npu.current_stream().wait_stream(stream_for_reshape_and_cache) + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + + if special_head_size_flag: + if sink_key_flag: + # actual_block_tables = attn_metadata.block_tables + # actual_seq_lengths = attn_metadata.seq_lens + + block_size = self.value_cache.shape[1] + num_batch = attn_metadata.query_lens.shape[0] + # sink stored in block 0 + block_tables = F.pad(attn_metadata.block_tables, (1, 0, 0, 0), value=0) + + # In PA (paged-attention) mode actual_seq_lengths is cumulative while actual_seq_lengths_kv is not; in non-PA mode both are cumulative + cu_seqlen_q = torch.arange(0, num_batch).npu() + 1 + attn_output = torch_npu.npu_fused_infer_attention_score( + query, + self.key_cache.view(-1, block_size, self.num_kv_heads * self.head_size), + self.value_cache.view(-1, block_size, self.num_kv_heads * v_head_size), + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="TND", + scale=self.scale, + block_table=block_tables[:num_batch], + block_size=block_size, + actual_seq_lengths=cu_seqlen_q, + actual_seq_lengths_kv=attn_metadata.seq_lens + sink_key.shape[0], + )[0] + + output.copy_(attn_output) + else: + block_num, block_size = self.key_cache.shape[0], self.key_cache.shape[1] + + num_batch = attn_metadata.seq_lens.shape[0] + query = query.view(num_batch, -1, self.num_heads * self.head_size) + block_tables = attn_metadata.block_tables + attn_output = None + if self.enable_graph_mode: + attn_output,
_ = tng.ops.npu_fused_infer_attention_score( + torch.transpose(query.view(num_batch, -1, self.num_heads, self.head_size), 1, 2), + self.key_cache, + self.value_cache, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BNSD", + scale=self.scale, + actual_seq_lengths_kv=attn_metadata.seq_lens, + block_table=block_tables, + block_size=block_size, + inner_precise=1 + ) + else: + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + query, + self.key_cache, + self.value_cache, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BSH", + scale=self.scale, + actual_seq_lengths_kv=attn_metadata.seq_lens, + block_table=block_tables, + block_size=block_size, + ) + + output = output.view_as(attn_output) + output.copy_(attn_output) + + # Normal V1 situation. + else: + # use chunked prefill for head size 192 scenario, like deepseek + # paged_attention_splitfuse maybe crash at such scenario + + all_key = self.key_cache.view(-1, self.num_kv_heads, self.head_size)[attn_metadata.kv_index].contiguous() + all_value = self.value_cache.view(-1, self.num_kv_heads, self.head_size)[attn_metadata.kv_index].contiguous() + actual_seq_qlen = np.array(attn_metadata.query_lens).cumsum().tolist() + actual_seq_kvlen = np.array(attn_metadata.seq_lens).cumsum().tolist() + attn_output = torch_npu.npu_fusion_attention( + query, + all_key, + all_value, + head_num=self.num_heads, + input_layout="TND", + scale=self.scale, + atten_mask=AscendAttentionBackendImpl.SHARE_MASK_TRIL_SPARSE, + sparse_mode=3, + actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen, + )[0] + + output = output.view_as(attn_output) + output.copy_(attn_output) + + return output.view(num_tokens, -1) + # return output.view(num_tokens, self.hidden_size) class AscendAttentionBackend(AttentionBackend): accept_output_buffer: bool = True diff --git a/omni/models/__init__.py b/omni/models/__init__.py index 7efe18677..cbdf85593 100644 --- 
a/omni/models/__init__.py +++ b/omni/models/__init__.py @@ -84,6 +84,10 @@ def register_model(): ModelRegistry.register_model( "PanguProMoEForCausalLM", "omni.models.pangu.pangu_pro_moe.pangu_moe:PanguProMoEForCausalLM") + + ModelRegistry.register_model( + "PanguProMoEV2ForCausalLM", + "omni.models.pangu.pangu_pro_moe_v2.pangu_moe_v2:PanguProMoEV2ForCausalLM") ModelRegistry.register_model( "PanguEmbeddedForCausalLM", diff --git a/omni/models/pangu/pangu_pro_moe_v2/fused_moe.py b/omni/models/pangu/pangu_pro_moe_v2/fused_moe.py new file mode 100644 index 000000000..6a6dbe386 --- /dev/null +++ b/omni/models/pangu/pangu_pro_moe_v2/fused_moe.py @@ -0,0 +1,762 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from the vllm-ascend project to reuse its model components +# for omni-infer integration. 
+# Adapted from vllm/tests/kernels/test_moe.py + +from typing import Callable, Optional + +import torch +import torch_npu + +from vllm.logger import init_logger +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_world_size) +# from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod +# from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, QuantizeMethodBase) +from typing import Callable, Optional, Union +from vllm.config import ParallelConfig, get_current_vllm_config +from vllm.model_executor.layers.fused_moe.layer import ( + UnquantizedFusedMoEMethod, FusedMoE, FusedMoEMethodBase, + FusedMoEParallelConfig, MoEConfig, determine_expert_map) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.attention.layer import Attention +import inspect # required by debug_data (inspect.stack) +logger = init_logger(__name__) + +PANGU_DEBUG = False +def debug_data(data, prefix=None, rank=0): + if not PANGU_DEBUG: + return + if isinstance(data, torch.Tensor): + suffix = f"{data} {data.dtype} {data.shape} {data.float().sum()}" + elif isinstance(data, list): + suffix = f"{data}" + elif isinstance(data, int): + suffix = f"{data}" + elif isinstance(data, str): + suffix = f"{data}" + else: + suffix = f"{data}" + local_rank = torch.distributed.get_rank() + if rank == -1 or rank == local_rank: + frame = inspect.stack()[1] + data_name = frame.code_context[0].split('(')[1].split(',')[0] + print(f"rank:{local_rank} file:{frame.filename} line:{frame.lineno} prefix:{prefix} {data_name} {suffix}", + flush=True) + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + apply_router_weight_on_input: bool = False, + max_num_tokens: Optional[int] =
None, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Fused experts with top-k routing. + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). + w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + top_k: Number of experts to select. + expert_map: Expert mapping of shape (num_experts,). + + Returns: + Tuple of (final_hidden_states, expert_token_count, group_list_type); group_list_type is 1 (per-expert counts). + """ + """ + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + """ + + original_shape = hidden_states.shape + + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + dtype = hidden_states.dtype + device = hidden_states.device + + ep_size = get_ep_group().world_size + n_total_expert = num_experts * ep_size + local_group_start = get_ep_group().rank_in_group * num_experts + local_group_end = (get_ep_group().rank_in_group + 1) * num_experts + expert_range = [local_group_start, local_group_end] + + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert (topk == 1), "Only support topk=1 when `apply_router_weight_on_input` is True" + hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) + + debug_data(expert_map, f"expert_map") + if expert_map is not None: + # Generate token indices and flatten + token_indices = (torch.arange(num_tokens, + device=device, + dtype=torch.int64).unsqueeze(1).expand( + -1, top_k).reshape(-1)) + + # Flatten
token-to-expert mappings and map to local experts + weights_flat = topk_weights.view(-1) + experts_flat = topk_ids.view(-1) + local_experts_flat = expert_map[experts_flat] + + # Filter valid token-expert pairs + mask = local_experts_flat != -1 + filtered_weights = torch.where( + mask, weights_flat, 0).to(dtype) + filtered_experts = torch.where( + mask, local_experts_flat, + torch.full_like(local_experts_flat, + num_experts)).to(topk_ids.dtype) + + # Sort by local expert IDs + sort_indices = torch.argsort(filtered_experts.float()) + sorted_token_indices = token_indices[sort_indices] + sorted_weights = filtered_weights[sort_indices] + + # Compute token counts with minlength of num_experts + # This is equivalent to but faster than: torch.bincount(filtered_experts, minlength=num_experts)[:-1] + token_counts = torch.zeros(num_experts + 1, + device=device, + dtype=torch.int64) + ones = torch.ones_like(filtered_experts, dtype=torch.int64) + token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) + token_counts = token_counts[:num_experts] + expert_token_count = token_counts # torch.cumsum(token_counts, dim=0, dtype=torch.int64) + + # Rearrange hidden_states + sorted_hidden_states = hidden_states[sorted_token_indices] + else: + sorted_hidden_states, expanded_row_idx, expert_token_count, _ = torch_npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + scale=None, + active_num=topk_ids.numel(), + expert_capacity=-1, + expert_num=n_total_expert, + drop_pad_mode=0, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + quant_mode=-1, + active_expert_range=expert_range, + row_idx_type=0, + ) + + debug_data(sorted_hidden_states, f"sorted_hidden_states") + # debug_data(expanded_row_idx, f"expanded_row_idx") + debug_data(expert_token_count, f"expert_token_count") + + + + gate_up_out = torch_npu.npu_grouped_matmul( + x=[sorted_hidden_states], + weight=[w1.transpose(1, 2)], + split_item=2, + group_list_type=1, # 0, # 0:cusum 1:count + group_type=0, + 
group_list=expert_token_count, + )[0] + debug_data(gate_up_out, f"gate_up_out") + gate_up_out = torch_npu.npu_swiglu(gate_up_out) + debug_data(gate_up_out, f"gate_up_out_after_swi") + down_out = torch_npu.npu_grouped_matmul( + x=[gate_up_out], + weight=[w2.transpose(1, 2)], + split_item=2, + group_list_type=1, + group_type=0, + group_list=expert_token_count, + )[0] + debug_data(down_out, f"down_out") + if expert_map is not None: + weighted_down_out = down_out * sorted_weights.unsqueeze(1) + + final_hidden_states = torch.zeros(*original_shape, + device=hidden_states.device, + dtype=dtype) + + # npu_grouped_matmul output random values at [num_valid_tokens:, ...] + # This created multiple NaN and index_add_ will mix them up which harms accuracy + # remove this mask and filter after it being fixed + num_valid_tokens = mask.sum() + valid_token_mask = (torch.arange( + 0, sorted_token_indices.shape[0], device=device).unsqueeze(1) < + num_valid_tokens) + valid_output = torch.where( + valid_token_mask, weighted_down_out, + 0).to(dtype) + final_hidden_states.index_add_(0, sorted_token_indices, valid_output) + else: + scales = (torch.ones_like(topk_weights) + if apply_router_weight_on_input else topk_weights) + final_hidden_states = torch_npu.npu_moe_finalize_routing( + down_out, + skip1=None, + skip2=None, + bias=None, + scales=scales, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + drop_pad_mode=2 + ) + group_list_type = 1 + return final_hidden_states, expert_token_count, group_list_type + + +def native_grouped_topk( + topk_weights: torch.Tensor, + num_expert_group: Optional[int], + topk_group: Optional[int], +): + topk_group = 0 if topk_group is None else topk_group + num_expert_group = 0 if num_expert_group is None else num_expert_group + + num_token = topk_weights.shape[0] + grouped_weights = topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values + topk_group_indices = torch.topk(grouped_weights.to(torch.float32), + 
k=topk_group, + dim=-1, + sorted=False)[1] + topk_group_mask = torch.zeros_like(grouped_weights) + topk_group_mask.scatter_(1, topk_group_indices, 1) + topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) + topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) + + return topk_weights + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + scaling_factor: Optional[float] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + global_num_experts: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Select top-k experts based on router logits. + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + router_logits: Router logits of shape (num_tokens, num_experts). + top_k: Number of experts to select. + use_grouped_topk: Whether to group experts before selecting top-k. + renormalize: Whether to renormalize the routing weights. + topk_group: Number of expert groups to select from. + num_expert_group: Number of experts in each group. + custom_routing_function: Custom routing function. + scoring_func: Scoring function to use. + e_score_correction_bias: Correction bias to apply to expert scores. + + Returns: + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + + Raises: + ValueError: If an unsupported scoring function is provided. 
+ """ + + def _renormalize_topk_weights( + topk_weights: torch.Tensor, + renormalize: bool, + ): + if renormalize: + topk_weights = topk_weights / (topk_weights.sum(dim=-1, + keepdim=True) + 1e-20) + return topk_weights + + if scoring_func == "softmax": + # vLLM use dtype float here + if not use_grouped_topk and custom_routing_function is None: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( + x=router_logits, finished=None, k=top_k) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + return topk_weights, topk_ids + + topk_weights = router_logits.softmax(dim=-1) + elif scoring_func == "sigmoid": + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # top-k is currently set to 8 + bias=(None if e_score_correction_bias is None else e_score_correction_bias.to(torch.bfloat16)), # bias is optional; do not call .to() on None + k_group=1, # fix: 4 + group_count=1, # fix 8 + group_select_mode=1, # 0: max within the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + routed_scaling_factor=(1.0 if scaling_factor is None else scaling_factor), # None means no scaling + eps=float(1e-20), + ) + return topk_weights, topk_ids + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_weights = topk_weights + topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0) + + # will change to npu_group_topk when the latest CANN and NNAL is available + topk_weights = native_grouped_topk(topk_weights, num_expert_group, + topk_group) + # bfloat16 is not supported in torch.topk with ge graph.
def UnquantizedFusedMoEMethod_apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    """Forward every routing argument to ``self.forward`` by keyword.

    ``patch_fused_moe_ops`` installs this function as
    ``UnquantizedFusedMoEMethod.apply``. Compared with a plain pass-through
    it additionally carries ``scaling_factor`` (pangu 72Bv2 scales the
    routed-expert weights), so the value reaches the forward implementation.

    Returns:
        Whatever ``self.forward`` returns for the given arguments.
    """
    forwarded = {
        "x": x,
        "layer": layer,
        "router_logits": router_logits,
        "top_k": top_k,
        "renormalize": renormalize,
        "use_grouped_topk": use_grouped_topk,
        "topk_group": topk_group,
        "num_expert_group": num_expert_group,
        # patch for pangu 72Bv2: need a scaling_factor for router experts
        "scaling_factor": scaling_factor,
        "global_num_experts": global_num_experts,
        "expert_map": expert_map,
        "custom_routing_function": custom_routing_function,
        "scoring_func": scoring_func,
        "e_score_correction_bias": e_score_correction_bias,
        "activation": activation,
        "apply_router_weight_on_input": apply_router_weight_on_input,
    }
    return self.forward(**forwarded)
def FusedMoE__init__(
    self,
    num_experts: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = False,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    # patch for pangu 72Bv2: need a scaling_factor for router experts
    scaling_factor: Optional[float] = None,
    tp_size: Optional[int] = None,
    ep_size: Optional[int] = None,
    dp_size: Optional[int] = None,
    prefix: str = "",
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    first_k_dense_replace: int = 3,
    **kwargs
):
    """Replacement ``FusedMoE.__init__`` that also stores ``scaling_factor``.

    Builds the MoE parallel config, decides the local expert set, records the
    routing options on ``self`` and asks the quantization method to create
    the expert weights.

    Args:
        num_experts: Global number of routed experts.
        top_k: Experts selected per token.
        hidden_size: Model hidden dimension.
        intermediate_size: Expert intermediate dimension before TP splitting;
            must be divisible by ``self.tp_size``.
        params_dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        scaling_factor: Stored as ``self.scaling_factor`` and later passed to
            the quant method by the patched forward.
        tp_size, dp_size: Optional overrides for the parallel sizes.
        ep_size, first_k_dense_replace, **kwargs: Accepted for signature
            compatibility; not read by this function.

    Raises:
        ValueError: If ``prefix`` is already registered in the compilation
            config's static forward context.
        RuntimeError: From ``check`` when a precondition fails.
    """
    super(FusedMoE, self).__init__()
    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    self.params_dtype = params_dtype

    vllm_config = get_current_vllm_config()
    self.moe_parallel_config: FusedMoEParallelConfig = (
        FusedMoEParallelConfig.make(
            tp_size_=(tp_size if tp_size is not None else
                      get_tensor_model_parallel_world_size()),
            dp_size_=(dp_size if dp_size is not None else
                      get_dp_group().world_size),
            vllm_parallel_config=vllm_config.parallel_config))

    self.global_num_experts = num_experts

    # For smuggling this layer into the fused moe custom op
    self.use_direct_call = self.dp_size == 1
    if not self.use_direct_call:
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError("Duplicate layer name: {}".format(prefix))
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

    # Debug-level logging instead of print(): this constructor runs once per
    # MoE layer on every rank, so stdout prints flood the worker logs.
    logger.debug("self.use_ep %s", self.use_ep)
    # Determine expert maps
    if self.use_ep:
        self.local_num_experts, self.expert_map = determine_expert_map(
            ep_size=self.ep_size,
            ep_rank=self.ep_rank,
            global_num_experts=self.global_num_experts)
        logger.debug("self.expert_map %s", self.expert_map)
    else:
        self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                   None)

    self.top_k = top_k

    check(intermediate_size % self.tp_size == 0)
    self.hidden_size = hidden_size
    self.intermediate_size_per_partition = intermediate_size // self.tp_size
    self.reduce_results = reduce_results
    self.renormalize = renormalize
    self.use_grouped_topk = use_grouped_topk
    if self.use_grouped_topk:
        check(num_expert_group is not None and topk_group is not None)
    self.num_expert_group = num_expert_group
    self.topk_group = topk_group
    # patch for pangu 72Bv2: need a scaling_factor for router experts
    self.scaling_factor = scaling_factor
    self.custom_routing_function = custom_routing_function
    self.scoring_func = scoring_func
    self.e_score_correction_bias = e_score_correction_bias
    self.apply_router_weight_on_input = apply_router_weight_on_input
    self.activation = activation

    moe = MoEConfig(
        num_experts=self.global_num_experts,
        experts_per_token=top_k,
        hidden_dim=hidden_size,
        num_local_experts=self.local_num_experts,
        moe_parallel_config=self.moe_parallel_config,
        in_dtype=params_dtype,
        max_num_tokens=None,
    )
    self.moe_config = moe
    self.quant_config = quant_config

    # Note: get_quant_method will look at the layer's local_num_experts
    # for heuristic purposes, so it must be initialized first.
    quant_method: Optional[QuantizeMethodBase] = None
    quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
                    else quant_config.get_quant_method(self, prefix))

    check(quant_method is not None)
    check(isinstance(quant_method, FusedMoEMethodBase))
    self.quant_method = quant_method

    moe_quant_params = {
        "num_experts": self.local_num_experts,
        "hidden_size": hidden_size,
        "intermediate_size_per_partition":
        self.intermediate_size_per_partition,
        "params_dtype": params_dtype,
        "weight_loader": self.weight_loader,
    }
    self.quant_method.create_weights(layer=self, **moe_quant_params)
def Attention_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    # For some alternate attention backends like MLA the attention output
    # shape does not match the query shape, so we optionally let the model
    # definition specify the output tensor shape.
    output_shape: Optional[torch.Size] = None,
    # patch for pangu 72Bv2 with attention sink
    sink_query: Optional[torch.Tensor] = None,
    sink_key: Optional[torch.Tensor] = None,
    sink_value: Optional[torch.Tensor] = None,
    v_head_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Replacement for ``Attention.forward`` (installed by
    ``patch_fused_moe_ops``) that accepts attention-sink tensors and a value
    head size that may differ from the query/key head size.

    The KV cache is stored inside this class and is accessed via
    `self.kv_cache`.

    Attention metadata (`attn_metadata`) is set using a context manager in
    the model runner's `execute_model` method. It is accessed via forward
    context using
    `vllm.forward_context.get_forward_context().attn_metadata`.

    Args:
        query, key, value: Projected attention inputs. On the non-MLA output
            path they are viewed as (-1, heads, head_dim) before dispatch.
        output_shape: Shape of the output buffer; defaults to ``query.shape``.
        sink_query, sink_key, sink_value: Optional sink tensors. They are
            passed to ``self.impl.forward`` only on the direct-call path and
            only when ``sink_query`` is not None.
        v_head_size: Per-head width of the output. When None it is derived as
            ``output_shape[-1] // self.num_heads`` on the non-MLA path.

    Returns:
        When ``self.use_output`` is set, the output buffer viewed as
        ``(-1, output_shape[-1])``; otherwise whatever the backend returns.
    """
    if self.calculate_kv_scales:
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata.enable_kv_scales_calculation:
            self.calc_kv_scales(query, key, value)
    if self.use_output:
        output_shape = (output_shape
                        if output_shape is not None else query.shape)
        output = torch.empty(output_shape,
                             dtype=query.dtype,
                             device=query.device)
        # Remember the flat width so the buffer can be flattened back below.
        hidden_size = output_shape[-1]
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        # We skip reshaping query, key and value tensors for the MLA
        # backend since these tensors have different semantics and are
        # processed differently.
        if not self.use_mla:
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            # patch for pangu 72Bv2: v head size is different from q and k
            if v_head_size is None:
                v_head_size = hidden_size // self.num_heads
            output = output.view(-1, self.num_heads, v_head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                # patch for pangu 72Bv2: v head size is different from q and k
                # The value head width is taken from the tensor itself rather
                # than from self.head_size.
                value = value.view(-1, self.num_kv_heads, value.shape[-1] // self.num_kv_heads)
        if self.use_direct_call:
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            self.impl.forward(self,
                              query,
                              key,
                              value,
                              self_kv_cache,
                              attn_metadata,
                              output=output,
                              # patch for pangu 72Bv2 with attention sink
                              # The extra keywords are only added when a sink
                              # query is supplied, so the backend call keeps
                              # its original argument list otherwise.
                              **(dict(sink_query=sink_query,
                                      sink_key=sink_key,
                                      sink_value=sink_value,
                                      v_head_size=v_head_size) if sink_query is not None else {}))
        else:
            # NOTE(review): this path passes neither the sink tensors nor
            # v_head_size to the custom op -- confirm the op does not need
            # them when direct call is disabled.
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, self.layer_name)
        return output.view(-1, hidden_size)
    else:
        if self.use_direct_call:
            forward_context = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            return self.impl.forward(self, query, key, value,
                                     self_kv_cache, attn_metadata)
        else:
            return torch.ops.vllm.unified_attention(
                query, key, value, self.layer_name)
+ + +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch_npu +from torch import nn +from torch.nn import Parameter +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import ( + divide, + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import get_dp_group, get_tp_group, get_world_group +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + LinearBase, + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader, sharded_weight_loader +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import ( + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils 
class CustomMergedColumnParallelLinear(LinearBase):
    """Merged column-parallel linear sharded over the whole world group.

    Adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear.
    It is used to customize parallelism of certain linear layers (e.g.,
    shared experts with all-rank tp): the partition count is
    ``get_world_group().world_size``. The input dimension is not partitioned;
    each entry of ``output_sizes`` is divided by that size.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        """Create the sharded weight (and optional bias).

        Args:
            input_size: Input feature size (kept whole on every rank).
            output_sizes: Output size of each merged sub-matrix.
            bias: Whether to allocate a bias parameter.
            gather_output: Stored on the layer; ``forward`` in this class
                does not gather.
            skip_bias_add: Return the bias instead of adding it.
            params_dtype: Parameter dtype.
            quant_config: Quantization configuration.
            prefix: Layer name prefix.
            return_bias: Whether ``forward`` returns ``(output, bias)``.

        Raises:
            ValueError: If no quantization method was resolved.
        """
        # Divide the weight matrix along the last dimension.
        output_size = sum(output_sizes)
        self.output_sizes = output_sizes
        self.tp_size = get_world_group().world_size
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        # One partition size per merged sub-matrix.
        self.output_partition_sizes = [divide(size, self.tp_size) for size in self.output_sizes]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
        )

        self.gather_output = gather_output

        if self.quant_method is None:
            raise ValueError("CustomMergedColumnParallelLinear self.quant_method is None")
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(self.weight_loader),
        )
        if bias:
            self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
        """Copy this rank's slice of one merged sub-matrix into ``param``.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor for sub-matrix ``loaded_shard_id``.
            loaded_shard_id: Index into ``self.output_sizes``.

        Raises:
            ValueError: If ``loaded_shard_id`` is out of range or the final
                shapes do not match.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        if loaded_shard_id >= len(self.output_sizes):
            raise ValueError(f"loaded_shard_id {loaded_shard_id} >= len(self.output_sizes) {len(self.output_sizes)}.")

        tp_rank = get_world_group().rank_in_group
        tp_size = get_world_group().world_size
        if output_dim is not None:
            # Offset/size of this sub-matrix inside the local merged parameter.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                # The checkpoint holds the full sub-matrix: take this rank's rows.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        if param_data.shape != loaded_weight.shape:
            raise ValueError("param_data.shape should be same with loaded_weight.shape, "
                             f"but param_data.shape is {param_data.shape}, loaded_weight.shape is {loaded_weight.shape}")
        param_data.copy_(loaded_weight)

    def forward(self, input_) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the local shard; no communication is performed here.

        Returns:
            ``output`` when ``return_bias`` is False, otherwise
            ``(output, bias)`` where ``bias`` is only non-None when
            ``skip_bias_add`` is set.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        if self.quant_method is None:
            raise ValueError("CustomMergedColumnParallelLinear self.quant_method is None")
        output = self.quant_method.apply(self, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias
class CustomQKVRearrangeColumnParallelLinear(ColumnParallelLinear):
    """Column-parallel QKV projection that re-orders fused checkpoint rows.

    Adapted from vllm.model_executor.layers.linear.ColumnParallelLinear.
    With tensor parallelism the fused weight is rearranged from
    ``q1 q2 k1 k2 v1 v2`` to ``q1 k1 v1 q2 k2 v2`` before the parent loader
    runs, so that slicing the rows into ``tp_size`` equal chunks gives every
    rank its own q/k/v rows.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        """Store ``config`` and defer everything else to the parent class."""
        self.config = config
        super().__init__(
            input_size,
            output_size,
            bias,
            gather_output,
            skip_bias_add,
            params_dtype,
            quant_config,
            output_sizes,
            prefix,
            return_bias=return_bias,
        )

    def weight_rearrange(self, loaded_weight):
        """Interleave the q/k/v row blocks of ``loaded_weight`` per TP rank.

        Args:
            loaded_weight: Fused tensor whose dim 0 is laid out as all q rows,
                then all k rows, then all v rows.

        Returns:
            ``loaded_weight`` unchanged when the TP world size is 1, otherwise
            a tensor whose dim 0 is ordered ``q_0 k_0 v_0 q_1 k_1 v_1 ...``.

        Raises:
            ValueError: If a required config attribute is missing or a block
                is not divisible by the TP world size.
        """
        tp_size = get_tp_group().world_size
        # if tp_size is 1, we do not need to rearrange the weight
        if tp_size == 1:
            return loaded_weight

        required = ("qk_nope_dim", "qk_rope_dim", "v_channels",
                    "num_key_value_heads", "num_attention_heads")
        missing = [name for name in required
                   if getattr(self.config, name, None) is None]
        if missing:
            raise ValueError(f"config is missing attributes required to rearrange qkv weight: {missing}")

        head_dim = self.config.qk_nope_dim + self.config.qk_rope_dim
        num_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        sizes = [
            num_heads * head_dim,
            num_kv_heads * head_dim,
            num_kv_heads * self.config.v_channels,
        ]
        for label, size in zip(("q", "k", "v"), sizes):
            if size % tp_size != 0:
                raise ValueError(f"tp_size is not correct. {label} rows {size} must be divisible by tp_size {tp_size}")

        # View each block as (rank, rows_per_rank, ...), concatenate the three
        # blocks rank-wise, then flatten back so rank i owns [q_i, k_i, v_i].
        per_rank = [
            block.reshape(tp_size, size // tp_size, -1)
            for block, size in zip(loaded_weight.split(sizes, dim=0), sizes)
        ]
        return torch.cat(per_rank, dim=1).flatten(0, 1)

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
    ):
        """Rearrange ``loaded_weight`` and delegate to the parent loader."""
        # Based on tp_size, rearrange the pattern from q1q2k1k2v1v2 to q1k1v1q2k2v2.
        loaded_weight_rearrange = self.weight_rearrange(loaded_weight)
        super().weight_loader(param, loaded_weight_rearrange)

    def weight_loader_v2(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
    ):
        """Rearrange ``loaded_weight`` and delegate to the parent v2 loader."""
        # Based on tp_size, rearrange the pattern from q1q2k1k2v1v2 to q1k1v1q2k2v2.
        loaded_weight_rearrange = self.weight_rearrange(loaded_weight)
        super().weight_loader_v2(param, loaded_weight_rearrange)
+ self.group = group if group is not None else get_world_group() + self.tp_rank = self.group.rank_in_group + self.tp_size = self.group.world_size + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + ) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + if self.quant_method is None: + raise ValueError("CustomRowParallelLinear self.quant_method is None") + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=(self.weight_loader), + ) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the results can lead to incorrect results") + + if bias: + self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = self.group.rank_in_group + input_dim = getattr(param, "input_dim", None) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + is_sharded_weight = is_sharded_weight + + param_data = param.data + if input_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). 
+ if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + if param_data.shape != loaded_weight.shape: + raise ValueError("param_data.shape should be same with loaded_weight.shape, " + f"but param_data.shape is {param_data.shape}, loaded_weight.shape is {loaded_weight.shape}") + param_data.copy_(loaded_weight) + + def forward(self, input_) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + input_parallel = input_ + + # Matrix multiply. + if self.quant_method is None: + raise ValueError("CustomRowParallelLinear self.quant_method is None") + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output = self.quant_method.apply(self, input_parallel, bias=bias_) + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + +class PanguProMoEMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + if not use_h2p(): + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + else: + self.gate_up_proj = CustomMergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = CustomRowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + 
class PanguProMoESparseMoeBlock(nn.Module):
    """Sparse MoE block: replicated gate, routed experts, optional shared MLP.

    Routed experts use sigmoid scoring with the configured routed scaling
    factor. The shared-expert output, when present, is added to the routed
    output before the tensor-parallel all-reduce.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build gate, routed experts and (optionally) shared experts.

        Raises:
            ValueError: If the TP world size exceeds ``config.num_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_experts = config.num_experts

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}."
            )

        self.num_experts_per_tok = config.num_experts_per_tok
        self.routed_scaling_factor = config.routed_scaling_factor
        self.norm_topk_prob = config.norm_topk_prob

        if config.router_enable_expert_bias:
            self.e_score_correction_bias = Parameter(torch.empty(self.num_experts, dtype=torch.float32))
        else:
            self.register_parameter("e_score_correction_bias", None)

        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=self.norm_topk_prob,
            quant_config=quant_config,
            scaling_factor=self.routed_scaling_factor,
            prefix=f"{prefix}.experts",
            scoring_func='sigmoid',
            e_score_correction_bias=self.e_score_correction_bias
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_experts = PanguProMoEMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_experts = None  # type: ignore

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        """Route tokens through the experts and add the shared-expert output.

        Args:
            hidden_states: Token activations whose last dimension is the
                hidden size; leading dimensions are flattened for routing.
            attn_metadata: Unused; kept for interface compatibility.

        Returns:
            A tensor with the same shape as ``hidden_states``.
        """
        # NOTE: hidden_states can have either 1D or 2D shape, so keep the
        # original shape instead of unpacking exactly two dimensions.
        orig_shape = hidden_states.shape
        hidden_dim = orig_shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        shared_output = None
        if self.shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
            debug_data(shared_output, "shared_output")

        router_logits, _ = self.gate(hidden_states)  # router_logits: (num_tokens, n_experts)
        h2p_enabled = use_h2p()
        if not h2p_enabled:
            e_hidden_states = self.experts.forward_impl(hidden_states=hidden_states, router_logits=router_logits)
            debug_data(e_hidden_states, "e_hidden_states")
        else:
            # when using h2p, we have to skip communication in vLLM
            # native FusedMoE. here we need to design a better FusedMoE
            # (maybe using AscendFusedMoE) to enable these different
            # communication schema.
            e_hidden_states = self.experts.quant_method.apply(
                layer=self.experts,
                x=hidden_states,
                router_logits=router_logits,
                top_k=self.experts.top_k,
                scoring_func='sigmoid',
                renormalize=self.norm_topk_prob,
                scaling_factor=self.routed_scaling_factor,
                e_score_correction_bias=self.e_score_correction_bias,
                use_grouped_topk=False,
                global_num_experts=self.experts.global_num_experts,
                expert_map=self.experts.expert_map)
        if isinstance(e_hidden_states, tuple):
            # Only the routed hidden states are consumed here; the remaining
            # tuple entries are ignored.
            e_hidden_states = e_hidden_states[0]
        final_hidden_states = e_hidden_states
        debug_data(final_hidden_states, "final_hidden_states")
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if not h2p_enabled:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(orig_shape)
+ if self.total_num_kv_heads % tp_size != 0: + raise ValueError("self.total_num_kv_heads % tp_size must be 0, " + F"but it is {self.total_num_kv_heads % tp_size}") + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + if tp_size % self.total_num_kv_heads != 0: + raise ValueError("self.total_num_kv_heads % tp_size must be 0, " + F"but it is {tp_size % self.total_num_kv_heads}") + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.qk_nope_dim = getattr(config, "qk_nope_dim", None) + self.qk_rope_dim = getattr(config, "qk_rope_dim", None) + self.v_channels = getattr(config, "v_channels", None) + self.head_dim = self.qk_nope_dim + self.qk_rope_dim + self.q_size = self.num_heads * self.head_dim + self.k_size = self.num_kv_heads * self.head_dim + self.v_size = self.num_kv_heads * self.v_channels + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.param_sink_number = getattr(config, "param_sink_number", 0) + self.param_sink_with_value = getattr(config, "param_sink_with_value", False) + self.param_sink_scalar = getattr(config, "param_sink_scalar", None) + self.param_sink_of_head_num = getattr(config, "param_sink_of_head_num", False) + + # use ColumnParallelLinear rather than QKVParallelLinear to keep same with training + # CustomQKVRearrangeColumnParallelLinear can arrange the qkv order when use tp + self.qkv_proj = CustomQKVRearrangeColumnParallelLinear( + config, + self.hidden_size, + (self.q_size + self.k_size + self.v_size) * tp_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + if use_h2p(): + self.o_proj = CustomRowParallelLinear( + self.total_num_heads * self.v_channels, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + group=get_tp_group(), + ) + 
else: + self.o_proj = RowParallelLinear( + self.total_num_heads * self.v_channels, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # native support for partial rope: qk[:64] + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.qk_rope_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + self.torchair_graph_enabled = enable_graph_mode + + if self.param_sink_number > 0: + self.param_sink_query = torch.zeros(( + self.param_sink_number, + self.num_heads, + self.head_dim), + device=torch.npu.current_device(), + dtype=config.torch_dtype, + ) + if self.param_sink_of_head_num: + self.param_sink_num_heads_per_partition = self.num_heads + self.q_mult = divide(self.num_heads, self.num_kv_heads) + else: + self.param_sink_num_heads_per_partition = self.num_kv_heads + if self.param_sink_scalar: + self.param_sink_key_zero_pad = torch.zeros(( + self.param_sink_number, + self.param_sink_num_heads_per_partition, + self.param_sink_scalar - 1), + device=torch.npu.current_device(), + dtype=config.torch_dtype, + ) + self.param_sink_key = torch.nn.Parameter( + torch.empty( + (self.param_sink_number, self.param_sink_num_heads_per_partition), + device=torch.npu.current_device(), + dtype=config.torch_dtype, + ) + ) + setattr(self.param_sink_key, 'allreduce', True) + else: + self.param_sink_key = torch.nn.Parameter( + torch.empty(( + self.param_sink_number, + self.param_sink_num_heads_per_partition, + self.head_dim), + device=torch.npu.current_device(), + dtype=config.torch_dtype, + ) + ) + setattr(self.param_sink_key, 'allreduce', True) + if self.param_sink_with_value: + self.param_sink_value = torch.nn.Parameter( + torch.empty(( + self.param_sink_number, + 
self.param_sink_num_heads_per_partition, + self.v_channels), + device=torch.npu.current_device(), + dtype=config.torch_dtype, + ) + ) + setattr(self.param_sink_value, 'allreduce', True) + else: + self.param_sink_value = torch.zeros(( + self.param_sink_number, + self.param_sink_num_heads_per_partition, + self.v_channels), + device=torch.npu.current_device(), + dtype=config.torch_dtype, + ) + self.enable_sink = False + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + k = self.k_layernorm(k.view(-1, self.num_kv_heads, self.head_dim)) + q, k = self.rotary_emb(positions, q.contiguous(), k) + + q = q.view(-1, self.q_size) + k = k.view(-1, self.k_size) + # pad v and attention sink after kv cache update in pangu_infer/patches/vllm_ascend/attention/attention_v1.py + param_sink_key = self.param_sink_key + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + if self.param_sink_number > 0 and attn_metadata is not None: + if hasattr(self, 'k_layernorm') and self.k_layernorm is not None: + param_sink_key = self.k_layernorm(self.param_sink_key) + self.enable_sink = True + if self.torchair_graph_enabled: + forward_kwargs = {"trace_flag": False} + output_shape = (q.shape[0], self.num_heads, self.v_channels) + attn_output = torch.empty(output_shape, dtype=q.dtype, device=q.device) + forward_kwargs["output"] = attn_output + attn_output = self.attn.impl.forward( + self.attn, q, k, v, kv_cache, attn_metadata, + **(dict( + sink_query=self.param_sink_query, + sink_key=param_sink_key, + sink_value=self.param_sink_value, + v_head_size=self.v_channels + ) if self.enable_sink else {}), **forward_kwargs) + else: + + # print("self.num_heads_before", self.num_heads) #16 + # print("self.v_channels", 
self.v_channels) #128 + # print("q.shape", q.shape) #[2048, 3072] = [2048, 16*192 = 64/4 * 192] + # print("k.shape", k.shape) #[2048, 192] + # print("v.shape", v.shape) #[2048, 128] + attn_output = self.attn( + q, + k, + v, + output_shape=(q.shape[0], self.num_heads * self.v_channels), #2048, 128*16=2048 + **(dict( + # output_shape=(q.shape[0], self.num_heads * self.v_channels), + sink_query=self.param_sink_query, + sink_key=param_sink_key, # k_layernorm online + sink_value=self.param_sink_value, + v_head_size=self.v_channels + ) if self.enable_sink and attn_metadata is not None else {}), + ) + attn_output = attn_output.reshape(-1, self.num_heads * self.v_channels) + output, _ = self.o_proj(attn_output) + return output + + +class PanguProMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + + self.self_attn = PanguProMoEV2Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. 
+ layer_idx = extract_layer_index(prefix) + self.layer_number = layer_idx + 1 + mlp_only_layers = [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + if (layer_idx not in mlp_only_layers) and (config.num_experts > 0): + self.mlp = PanguProMoESparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = PanguProMoEMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if getattr(config, 'sandwich_norm', False): + self.sandwich_norm = True + self.pre_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.sandwich_norm = False + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + h2p_unpad_idx: Optional[torch.Tensor] = None, + h2p_pad_idx: Optional[torch.Tensor] = None, + is_start_layer: Optional[bool] = False, + ) -> torch.Tensor: + need_h2p_pad = ( + h2p_unpad_idx is not None and h2p_pad_idx is not None and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0] + ) + tp_size = get_tp_group().world_size + debug_data(hidden_states, f"layer{self.layer_number} input") + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + debug_data(hidden_states, f"layer{self.layer_number} input_layernorm output") + if use_h2p(): + if is_start_layer: + if need_h2p_pad: + 
residual = residual.index_select(dim=0, index=h2p_pad_idx) + residual = torch.tensor_split(residual, tp_size)[get_tp_group().rank_in_group] + else: + if tp_size > 1: + hidden_states = get_tp_group().all_gather(hidden_states, 0) + if need_h2p_pad: + hidden_states = hidden_states.index_select(dim=0, index=h2p_unpad_idx) + + debug_data(hidden_states, f"layer{self.layer_number} attn input") + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + debug_data(hidden_states, f"layer{self.layer_number} attn output") + + if use_h2p(): + if need_h2p_pad: + hidden_states = hidden_states.index_select(dim=0, index=h2p_pad_idx) + if tp_size > 1: + hidden_states = dist._functional_collectives.reduce_scatter_tensor( + hidden_states, + "sum", + scatter_dim=0, + group=get_tp_group().device_group, + ) + + debug_data(hidden_states, f"layer{self.layer_number} post_attention_layernorm input") + # Fully Connected + if self.sandwich_norm: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.pre_mlp_layernorm(hidden_states) + else: + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + if use_h2p(): + all_rank_group = get_world_group().device_group + output_size = ( + hidden_states.shape[0] * get_world_group().world_size, + hidden_states.shape[1], + ) + # Allocate output tensor. + output_tensor = torch.empty(output_size, dtype=hidden_states.dtype, device=hidden_states.device) + # All-gather. 
@support_torch_compile
class PanguProMoEModel(nn.Module):
    """Decoder stack: token embedding, this rank's decoder layers, final RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final norm.

        Also sets the module-level ``enable_graph_mode`` flag from the NPU
        compilation level; the attention layers read it while they are built.
        """
        super().__init__()
        global enable_graph_mode
        from vllm.config import CompilationLevel
        enable_graph_mode = (vllm_config.npu_compilation_config.level != CompilationLevel.NO_COMPILATION)
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PanguProMoEDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's decoder layers.

        Returns:
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``
            unless this is the last pipeline rank, in which case the
            final-normalised hidden states are returned.

        Raises:
            ValueError: on a non-first pipeline rank when
                ``intermediate_tensors`` is not provided.
        """
        debug_data(input_ids, "token ids")
        debug_data(positions, "position ids")
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # The previous check read `self.quant_method`, which this module
            # never defines, so every non-first rank raised AttributeError.
            # The value that must be present here is `intermediate_tensors`.
            if intermediate_tensors is None:
                raise ValueError(
                    "PanguProMoEModel.forward requires intermediate_tensors on non-first pipeline ranks"
                )
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        if use_h2p():
            # calculate necessary padding/unpadding idx before model forward.

            # the attn_metadata will be passed directly when use torchair.
            # if attn_metadata is not passed, we try to get it from forward_context.
            if attn_metadata is None:
                attn_metadata = get_forward_context().attn_metadata
            max_tokens_across_dp = get_forward_context().max_tokens_across_dp
            tp_size = get_tp_group().world_size
            # reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks.
            # we need pad it before if the shape can't be divided by group size.
            # for h2p, we need pad it so that it can be divided by tp_size.
            h2p_padded_len = (
                (tp_size - (max_tokens_across_dp % tp_size)) % tp_size + max_tokens_across_dp - hidden_states.shape[0]
            )
            h2p_unpad_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device, dtype=torch.int32)
            # Padding rows index row 0, so index_select duplicates the first token.
            h2p_pad_idx = torch.cat(
                [
                    h2p_unpad_idx,
                    torch.zeros(h2p_padded_len, dtype=torch.int32, device=hidden_states.device),
                ]
            )
        else:
            h2p_unpad_idx = None
            h2p_pad_idx = None

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                kv_caches[i - self.start_layer] if kv_caches is not None else None,
                attn_metadata,
                h2p_unpad_idx,
                h2p_pad_idx,
                i == self.start_layer,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
        hidden_states = self.norm(hidden_states)
        debug_data(hidden_states, "final_layernorm output")
        if use_h2p():
            # Undo the h2p layout: gather the per-rank shards, then drop padding rows.
            if get_tp_group().world_size > 1:
                hidden_states = get_tp_group().all_gather(hidden_states, 0)
            if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]:
                hidden_states = hidden_states.index_select(dim=0, index=h2p_unpad_idx)
        return hidden_states
prefix=f"{prefix}.lm_head", + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + selected_indices: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + debug_data(logits, f"logits_processor output") + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + tp_size = get_tp_group().world_size + tp_rank = get_tp_group().rank_in_group + stacked_params_mapping = [ + # param_name, shard_name, shard_id + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # param_name, weight_name, expert_id, shard_id + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + 
ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + params_dict = dict(self.named_parameters()) # from model + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + # ======================================================= + # BF: add this to load with less layers + if valid_name_layer(name, self): + continue + if valid_name(name): + continue + + if name.endswith("k_proj.kv_cache_scale"): + remapped_kv_scale_name = name.replace("k_proj.kv_cache_scale", "attn.key_antiquant_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded." + ) + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + loaded_weight = torch.tensor_split(loaded_weight, tp_size, dim=0)[tp_rank] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + if name.endswith("v_proj.kv_cache_scale"): + remapped_kv_scale_name = name.replace("v_proj.kv_cache_scale", "attn.value_antiquant_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded." + ) + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + loaded_weight = torch.tensor_split(loaded_weight, tp_size, dim=0)[tp_rank] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). 
def valid_name(name: str) -> bool:
    """Return True when the checkpoint entry ``name`` must be skipped at load time.

    Skipped entries are rotary inverse-frequency buffers, any name containing
    ``"module"``, and names ending in ``kv_cache_offset``.
    """
    return (
        "rotary_emb.inv_freq" in name
        or "module" in name
        or name.endswith("kv_cache_offset")
    )


def valid_name_layer(name: str, self) -> bool:
    """Return True when ``name`` belongs to a layer at or past ``self.model.end_layer``.

    The layer index is the integer following the last ``"layers."`` in
    ``name``. Names without such a numeric segment are never skipped;
    previously a name that merely contained ``"layers"`` (for example
    ``"model.layers_norm.weight"``) raised ``ValueError`` from ``int()``.
    """
    _, marker, tail = name.rpartition("layers.")
    if not marker:
        return False
    index_token = tail.split(".", 1)[0]
    if not index_token.isdecimal():
        return False
    return int(index_token) >= self.model.end_layer


# NOTE: the misspelled name is kept because load_weights calls it by this name.
def valid_stack_mappingng(param_name, weight_name, name: str, params_dict, self):
    """Decide whether a stacked-parameter mapping applies to ``name``.

    Returns:
        ``(skip, name)``. ``skip`` is True when the mapping must not be used;
        ``name`` is the possibly rewritten parameter name, with
        ``weight_name`` replaced by ``param_name``.
    """
    if weight_name not in name:
        return True, name
    # We have mlp.experts[0].gate_proj in the checkpoint.
    # Since we handle the experts below in expert_params_mapping,
    # we need to skip here BEFORE we update the name, otherwise
    # name will be updated to mlp.experts[0].gate_up_proj, which
    # will then be updated below in expert_params_mapping
    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
    if "mlp.experts" in name:
        return True, name
    name = name.replace(weight_name, param_name)
    # Skip loading extra bias for GPTQ models.
    if name.endswith((".bias", "_bias")) and name not in params_dict:
        return True, name
    # Skip layers on other devices.
    if is_pp_missing_parameter(name, self):
        return True, name
    if name not in params_dict:
        return True, name
    return False, name


def valid_expert_mapping(weight_name, param_name, name, params_dict, self):
    """Decide whether an expert-parameter mapping applies to ``name``.

    Returns:
        ``(skip, name)`` with the same meaning as in ``valid_stack_mappingng``.
    """
    if weight_name not in name:
        return True, name
    name = name.replace(weight_name, param_name)
    # Skip layers on other devices.
    if is_pp_missing_parameter(name, self):
        return True, name
    # Skip loading extra bias for GPTQ models.
    if name.endswith((".bias", "_bias")) and name not in params_dict:
        return True, name
    return False, name
import asyncio

import aiohttp  # third-party dependency: requires the aiohttp package


async def send_single_request(session, url, data):
    """POST ``data`` as JSON to ``url`` and return the decoded JSON response."""
    async with session.post(url, json=data) as response:
        return await response.json()


async def main():
    """Send two completion requests concurrently and print each response."""
    url = "http://127.0.0.1:9555/v1/completions"

    # Two request payloads; they are currently identical, which exercises
    # concurrent handling of the same prompt.
    base_request = {
        "model": "PanguProMoE",
        "prompt": ["请介绍一下苏州,不少于50字"],
        "max_tokens": 12,
        "temperature": 0,
        "top_p": 1,
        "top_k": 1
    }
    requests_data = [dict(base_request) for _ in range(2)]

    async with aiohttp.ClientSession() as session:
        # Send all requests concurrently.
        tasks = [send_single_request(session, url, data) for data in requests_data]
        responses = await asyncio.gather(*tasks)

    # Print every response, numbered from 1.
    for index, response in enumerate(responses, start=1):
        print(f"响应 {index}: {response}")


# Only run when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    asyncio.run(main())
+ +export GLOO_SOCKET_IFNAME=enp23s0f3 +export TP_SOCKET_IFNAME=enp23s0f3 +# enp67s0f5 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 +export VLLM_USE_V1=1 +export VLLM_WORKER_MULTIPROC_METHOD=fork +export VLLM_ENABLE_MC2=0 +export USING_LCCL_COM=0 +export ASCEND_LAUNCH_BLOCKING=1 +export ENABLE_PREFILL_TND=1 +# export OMNI_USE_QWEN=1 + +python start_api_servers.py \ + --num-servers 1 \ + --model-path /mnt/sfs_turbo/bucket-910c-6055/rrj/model/MoE_74BA15B/vllm_ascend_mtp/eighth_2550b/fix/iter_0009500/ \ + --master-ip 8.8.8.8 \ + --tp 4 \ + --master-port 35678 \ + --served-model-name PanguProMoE \ + --log-dir apiserverlog \ + --extra-args "--enforce-eager --enable-expert-parallel " \ + --gpu-util 0.6 \ + --base-api-port 9555 \ + --max-model-len 4096 \ + --no-enable-prefix-caching \ + --additional-config '{ "enable_hybrid_graph_mode": true}' # 混部模式开启enable_hybrid_graph_mode -- Gitee From 28f3a56d7aae6b20aa9015fec0e9792e767a6aec Mon Sep 17 00:00:00 2001 From: Nickyi Date: Tue, 11 Nov 2025 10:45:15 +0800 Subject: [PATCH 2/6] update --- omni/layers/attention/backend/attention.py | 285 +++++++----------- .../pangu/pangu_pro_moe_v2/pangu_moe_v2.py | 20 +- 2 files changed, 124 insertions(+), 181 deletions(-) diff --git a/omni/layers/attention/backend/attention.py b/omni/layers/attention/backend/attention.py index 803a20689..c5d5f888c 100644 --- a/omni/layers/attention/backend/attention.py +++ b/omni/layers/attention/backend/attention.py @@ -45,6 +45,21 @@ from omni.models.config_loader.loader import model_extra_config NZ_DIM = 16 +import torchair as tng +import vllm.envs as envs + +config = tng.CompilerConfig() +# Set the export image structure file format +# config.debug.graph_dump.type = "py" +config.experimental_config.frozen_parameter = True +config.experimental_config.tiling_schedule_optimize = True +torch.npu.set_compile_mode(jit_compile=False) +npu_backend = tng.get_npu_backend(compiler_config=config) + +@torch.compile(backend=npu_backend, 
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, dynamic=True) +def print_tensor(x): + tng.ops.npu_print("print, tensor:", x) + class AscendAttentionState(Enum): PrefillNoCache = 0 PrefillCacheHit = 1 @@ -509,12 +524,11 @@ class AscendAttentionBackendImpl(AttentionImpl): ) self.use_tnd_pa = model_extra_config.operator_opt_config.use_tnd_pa self.kv_stream = kv_stream - self.use_sink = getattr(cur_vllm_config.model_config.hf_config, "v_channels", None) is not None def forward(self, *args, **kwargs): if self.use_tnd_pa: return self.forward_pa(*args, **kwargs) - elif self.use_sink: + elif "sink_key" in kwargs: return self.forward_sink(*args, **kwargs) else: return self.forward_vanilla(*args, **kwargs) @@ -867,19 +881,18 @@ class AscendAttentionBackendImpl(AttentionImpl): Returns: shape = [batch_size * seq_len, num_heads, head_size] """ + sink_key_flag = (sink_key is not None) + assert sink_key_flag num_tokens = query.shape[0] if v_head_size == None: v_head_size = self.head_size if output is None: output = torch.empty(num_tokens, self.num_heads, - self.head_size, + v_head_size, dtype=query.dtype, device=query.device) - # print("num_tokens", num_tokens) #2048 - # print("self.num_heads", self.num_heads) #16 - # print("self.head_size", self.head_size) #192 - # print("self.hidden_size", self.hidden_size) #3072 + if attn_metadata is None: return output.view(num_tokens, -1) @@ -893,34 +906,24 @@ class AscendAttentionBackendImpl(AttentionImpl): "PallasAttentionBackendImpl") # View q k v to BSH. 
special_head_size_flag = (self.head_size == 192) - sink_key_flag = (sink_key is not None) - # print("query.shape", query.shape) #[256, 16, 192] - # print("key.shape", key.shape) #[256, 1, 192] - # print("value.shape", value.shape) query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) - if sink_key_flag: - value = value.view(-1, self.num_kv_heads, v_head_size) # v has a different h - else: - value = value.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, v_head_size) # v has a different h value = value.contiguous() # update kv cache if kv_cache[0].numel() > 0 or kv_cache[1].numel(): self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] - # print("self.value_cache.shape", self.value_cache.shape) block_size = self.key_cache.shape[1] # kv_cache: shape = [2, num_blocks, block_size, # num_kv_heads * head_size] cast_key = key.reshape(-1, 1, self.num_kv_heads * self.head_size) - if sink_key_flag: - cast_value = value.reshape(-1, 1, self.num_kv_heads * v_head_size) - else: - cast_value = value.reshape(-1, 1, self.num_kv_heads * self.head_size) + cast_value = value.reshape(-1, 1, self.num_kv_heads * v_head_size) + if attn_metadata.attn_state != AscendAttentionState.DecodeOnly: # if prefill does not use paged attention, # (1) saving keys and values into kv_cache, and @@ -932,11 +935,6 @@ class AscendAttentionBackendImpl(AttentionImpl): else: stream_for_reshape_and_cache = torch.npu.current_stream() with torch.npu.stream(stream_for_reshape_and_cache): - # print("key.shape", key.shape) #[8, 1, 192] - # print("value.shape", value.shape) #[8, 1, 128] - # print("self.key_cache.view(self.key_cache.shape[0], block_size, self.num_kv_heads, self.head_size).shape", self.key_cache.view(self.key_cache.shape[0], block_size, self.num_kv_heads, self.head_size).shape) #[140, 128, 1, 192] - # print("self.value_cache.view(self.value_cache.shape[0], block_size, self.num_kv_heads, 
self.v_head_size).shape", self.value_cache.view(self.value_cache.shape[0], block_size, self.num_kv_heads, v_head_size).shape) #[140, 128, 1, 128] - # print("attn_metadata.slot_mapping.int().shape", attn_metadata.slot_mapping.int().shape) #[8] torch_npu._npu_reshape_and_cache( key, value, @@ -945,35 +943,14 @@ class AscendAttentionBackendImpl(AttentionImpl): attn_metadata.slot_mapping.int() ) else: - # print("self.key_cache", self.key_cache.shape) #[140, 128, 192] - # print("attn_metadata.slot_indices", attn_metadata.slot_indices.shape) #[256, 2] - # print("cast_key", cast_key.shape) #[256, 1, 192] - # print("self.value_cache", self.value_cache.shape) - # print("attn_metadata.slot_indices", attn_metadata.slot_indices) - # print("cast_value", cast_value.shape) ##[256, 1, 128] - - if sink_key_flag: - torch_npu._npu_reshape_and_cache( - cast_key, - cast_value, - self.key_cache.view(self.key_cache.shape[0], block_size, self.num_kv_heads, self.head_size), - self.value_cache.view(self.value_cache.shape[0], block_size, self.num_kv_heads, v_head_size), - attn_metadata.slot_mapping.int() - ) - else: - torch_npu.scatter_update_(self.key_cache, attn_metadata.slot_indices, cast_key, -2) - torch_npu.scatter_update_(self.value_cache, attn_metadata.slot_indices, cast_value, -2) + torch_npu.scatter_update_(self.key_cache, attn_metadata.slot_indices, cast_key, -2) + torch_npu.scatter_update_(self.value_cache, attn_metadata.slot_indices, cast_value, -2) + - if sink_key_flag and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: # kv cache start from block 1 and slots 128, so we store sink in block 0. 
slots = torch.arange(0, 128, device=sink_key.device, dtype=torch.int32) bsz = attn_metadata.query_lens.shape[0] - # print("sink_key.shape", sink_key.shape) - # print("sink_value.shape", sink_value.shape) - - # print("self.key_cache.shape", self.key_cache.shape) - # print("self.value_cache.shape", self.value_cache.shape) - # print("slots.shape", slots.shape) torch_npu._npu_reshape_and_cache( key=sink_key, value=sink_value, @@ -981,7 +958,6 @@ class AscendAttentionBackendImpl(AttentionImpl): value_cache=self.value_cache.view(self.value_cache.shape[0], block_size, self.num_kv_heads, v_head_size), slot_indices=slots) - if hasattr(layer, 'quant_method'): pass # V0-Style scheduler situation. @@ -1004,8 +980,6 @@ class AscendAttentionBackendImpl(AttentionImpl): actual_seq_lengths_kv=attn_metadata.seq_lens_list, atten_mask=AscendAttentionBackendImpl.SHARE_MASK_TRIL_SPARSE, )[0].view(-1, self.num_heads, self.head_size) - # print("attn_output.shape", attn_output.shape) - # print("output.shape", output.shape) output = output.view_as(attn_output) output.copy_(attn_output) @@ -1030,143 +1004,99 @@ class AscendAttentionBackendImpl(AttentionImpl): if attn_metadata is None: raise RuntimeError("attn_metadata must not be None") - - if sink_key_flag: - bsz = attn_metadata.query_lens.shape[0] - cu_seqlen_q = [0] + attn_metadata.query_lens.tolist() - cu_seqlen_q = torch.tensor(cu_seqlen_q, device=key.device) - cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0) - key_list = [] - value_list = [] - for i in range(bsz): - k = key[cu_seqlen_q[i]:cu_seqlen_q[i+1]] - v = value[cu_seqlen_q[i]:cu_seqlen_q[i+1]] - key_list.append(torch.cat([sink_key, k], dim=0)) - value_list.append(torch.cat([sink_value, v], dim=0)) - key = torch.cat(key_list, dim=0) - value = torch.cat(value_list, dim=0) - - - if special_head_size_flag: - atten_mask = ~torch.tril( - torch.ones((2048, 2048), device='npu', dtype=torch.bool) - ) - if sink_key_flag: - seq_lens_with_sink = attn_metadata.seq_lens + sink_key.shape[0] - 
cu_seqlen = [0] + (attn_metadata.seq_lens.tolist() if sink_key is None else seq_lens_with_sink.tolist()) - cu_seqlen = torch.tensor(cu_seqlen, device="npu") - cu_seqlen = torch.cumsum(cu_seqlen, dim=0)[1:] - attn_output = torch_npu.npu_fused_infer_attention_score( - query, - key, - value, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout="TND", - scale=self.scale, - sparse_mode=3, - pre_tokens=2147483647, - next_tokens=0, - atten_mask=atten_mask, - inner_precise=0, - actual_seq_lengths=cu_seqlen_q[1:], - actual_seq_lengths_kv=cu_seqlen)[0] - output.copy_(attn_output) - output = output.view(-1, self.num_heads * v_head_size) - - - else: - actual_seq_qlen = np.array(attn_metadata.query_lens).cumsum().tolist() - actual_seq_kvlen = np.array(attn_metadata.seq_lens).cumsum().tolist() - attn_output = torch_npu.npu_fused_infer_attention_score( - query[:actual_seq_qlen[-1],:,:], - key[:actual_seq_qlen[-1],:,:], - value[:actual_seq_qlen[-1],:,:], - num_heads = self.num_heads, - num_key_value_heads = self.num_kv_heads, - input_layout = "TND", - scale = self.scale, - sparse_mode = 3, - actual_seq_lengths = actual_seq_qlen, - actual_seq_lengths_kv = actual_seq_kvlen, - atten_mask = AscendAttentionBackendImpl.SHARE_MASK_TRIL_SPARSE, - )[0].view(-1, self.num_heads, self.head_size) - - output[:actual_seq_qlen[-1], :].copy_(attn_output) + bsz = attn_metadata.query_lens.shape[0] + cu_seqlen_q = [0] + attn_metadata.query_lens.tolist() + cu_seqlen_q = torch.tensor(cu_seqlen_q, device=key.device) + cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0) + key_list = [] + value_list = [] + for i in range(bsz): + k = key[cu_seqlen_q[i]:cu_seqlen_q[i+1]] + v = value[cu_seqlen_q[i]:cu_seqlen_q[i+1]] + key_list.append(torch.cat([sink_key, k], dim=0)) + value_list.append(torch.cat([sink_value, v], dim=0)) + key = torch.cat(key_list, dim=0) + value = torch.cat(value_list, dim=0) + + atten_mask = ~torch.tril( + torch.ones((2048, 2048), device='npu', dtype=torch.bool) + 
) + seq_lens_with_sink = attn_metadata.seq_lens + sink_key.shape[0] + cu_seqlen = [0] + (attn_metadata.seq_lens.tolist() if sink_key is None else seq_lens_with_sink.tolist()) + cu_seqlen = torch.tensor(cu_seqlen, device="npu") + cu_seqlen = torch.cumsum(cu_seqlen, dim=0)[1:] + attn_output = torch_npu.npu_fused_infer_attention_score( + query, + key, + value, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="TND", + scale=self.scale, + sparse_mode=3, + pre_tokens=2147483647, + next_tokens=0, + atten_mask=atten_mask, + inner_precise=0, + actual_seq_lengths=cu_seqlen_q[1:], + actual_seq_lengths_kv=cu_seqlen)[0] + output.copy_(attn_output) + output = output.view(-1, self.num_heads * v_head_size) + if stream_for_reshape_and_cache != torch.npu.current_stream(): torch.npu.current_stream().wait_stream(stream_for_reshape_and_cache) elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + # actual_block_tables = attn_metadata.block_tables + # actual_seq_lengths = attn_metadata.seq_lens - if special_head_size_flag: - if sink_key_flag: - # actual_block_tables = attn_metadata.block_tables - # actual_seq_lengths = attn_metadata.seq_lens + block_size = self.value_cache.shape[1] + num_batch = attn_metadata.query_lens.shape[0] + # sink stored in block 0 + block_tables = F.pad(attn_metadata.block_tables, (1, 0, 0, 0), value=0) - block_size = self.value_cache.shape[1] - num_batch = attn_metadata.query_lens.shape[0] - # sink stored in block 0 - block_tables = F.pad(attn_metadata.block_tables, (1, 0, 0, 0), value=0) - - # PA模式actual_seq_lengths累加,actual_seq_lengths_kv不累加;非PA模式都是累加 - cu_seqlen_q = torch.arange(0, num_batch).npu() + 1 - attn_output = torch_npu.npu_fused_infer_attention_score( - query, - self.key_cache.view(-1, block_size, self.num_kv_heads * self.head_size), - self.value_cache.view(-1, block_size, self.num_kv_heads * v_head_size), - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout="TND", - 
scale=self.scale, - block_table=block_tables[:num_batch], - block_size=block_size, - actual_seq_lengths=cu_seqlen_q, - actual_seq_lengths_kv=attn_metadata.seq_lens + sink_key.shape[0], - )[0] + if self.enable_graph_mode: + # actual_seq_lengths = attn_metadata.query_lens.cumsum(dim=0) + actual_seq_lengths_kv = attn_metadata.seq_lens + 128 + attn_output = tng.ops.npu_fused_infer_attention_score( + torch.transpose(query.view(num_batch, -1, self.num_heads, self.head_size), 1, 2), + self.key_cache, + self.value_cache, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BNSD", + scale=self.scale, + block_table=block_tables, + block_size=block_size, + actual_seq_lengths_kv=actual_seq_lengths_kv, + inner_precise=1 + )[0] - output.copy_(attn_output) + attn_output = attn_output[:, :, :, :v_head_size] + output = output.view_as(attn_output) + output.copy_(attn_output) else: - block_num, block_size = self.key_cache.shape[0], self.key_cache.shape[1] - - num_batch = attn_metadata.seq_lens.shape[0] - query = query.view(num_batch, -1, self.num_heads * self.head_size) - block_tables = attn_metadata.block_tables - attn_output = None - if self.enable_graph_mode: - attn_output, _ = tng.ops.npu_fused_infer_attention_score( - torch.transpose(query.view(num_batch, -1, self.num_heads, self.head_size), 1, 2), - self.key_cache, - self.value_cache, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout="BNSD", - scale=self.scale, - actual_seq_lengths_kv=attn_metadata.seq_lens, - block_table=block_tables, - block_size=block_size, - inner_precise=1 - ) - else: - attn_output, _ = torch_npu.npu_fused_infer_attention_score( - query, - self.key_cache, - self.value_cache, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout="BSH", - scale=self.scale, - actual_seq_lengths_kv=attn_metadata.seq_lens, - block_table=block_tables, - block_size=block_size, - ) - + # 
PA模式actual_seq_lengths累加,actual_seq_lengths_kv不累加;非PA模式都是累加 + cu_seqlen_q = torch.arange(0, num_batch).npu() + 1 + attn_output = torch_npu.npu_fused_infer_attention_score( + query, + self.key_cache.view(-1, block_size, self.num_kv_heads * self.head_size), + self.value_cache.view(-1, block_size, self.num_kv_heads * v_head_size), + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="TND", + scale=self.scale, + block_table=block_tables[:num_batch], + block_size=block_size, + actual_seq_lengths=cu_seqlen_q, + actual_seq_lengths_kv=attn_metadata.seq_lens + sink_key.shape[0], + )[0] output = output.view_as(attn_output) output.copy_(attn_output) - + # Normal V1 situation. else: # use chunked prefill for head size 192 scenario, like deepseek # paged_attention_splitfuse maybe crash at such scenario - all_key = self.key_cache.view(-1, self.num_kv_heads, self.head_size)[attn_metadata.kv_index].contiguous() all_value = self.value_cache.view(-1, self.num_kv_heads, self.head_size)[attn_metadata.kv_index].contiguous() actual_seq_qlen = np.array(attn_metadata.query_lens).cumsum().tolist() @@ -1188,7 +1118,6 @@ class AscendAttentionBackendImpl(AttentionImpl): output.copy_(attn_output) return output.view(num_tokens, -1) - # return output.view(num_tokens, self.hidden_size) class AscendAttentionBackend(AttentionBackend): accept_output_buffer: bool = True diff --git a/omni/models/pangu/pangu_pro_moe_v2/pangu_moe_v2.py b/omni/models/pangu/pangu_pro_moe_v2/pangu_moe_v2.py index 7b34d5b3d..b7e900430 100644 --- a/omni/models/pangu/pangu_pro_moe_v2/pangu_moe_v2.py +++ b/omni/models/pangu/pangu_pro_moe_v2/pangu_moe_v2.py @@ -70,6 +70,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors +from omni.layers.attention.backend.attention import AscendAttentionState from .fused_moe import patch_fused_moe_ops enable_graph_mode = False @@ -852,6 +853,7 
@@ class PanguProMoEDecoderLayer(nn.Module): # `mlp_only_layers` in the config. layer_idx = extract_layer_index(prefix) + self.layer_name = f"{prefix}.self_attn.attn" self.layer_number = layer_idx + 1 mlp_only_layers = [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers if (layer_idx not in mlp_only_layers) and (config.num_experts > 0): @@ -980,14 +982,13 @@ class PanguProMoEDecoderLayer(nn.Module): return hidden_states, residual -@support_torch_compile class PanguProMoEModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() global enable_graph_mode from vllm.config import CompilationLevel - enable_graph_mode = (vllm_config.npu_compilation_config.level != CompilationLevel.NO_COMPILATION) + # enable_graph_mode = (vllm_config.npu_compilation_config.level != CompilationLevel.NO_COMPILATION) config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -1092,7 +1093,7 @@ class PanguProMoEModel(nn.Module): hidden_states = hidden_states.index_select(dim=0, index=h2p_unpad_idx) return hidden_states - +@support_torch_compile class PanguProMoEV2ForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False @@ -1265,6 +1266,19 @@ class PanguProMoEV2ForCausalLM(nn.Module, SupportsPP): loaded_params.add(name) return loaded_params + def should_use_eager_mode(self, *args, **kwargs): + """Return if a layer should use eager mode. This function is + to fit the attention backend of Omni infer. 
+ + Returns: + bool: True for eager mode, False for graph mode + """ + attn_metadata = kwargs.get("attn_metadata", None) + if not attn_metadata: + return True + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.model.layers[self.model.start_layer].layer_name] + return attn_metadata.attn_state != AscendAttentionState.DecodeOnly def valid_name(name: str) -> bool: if "rotary_emb.inv_freq" in name or "module" in name or name.endswith("kv_cache_offset"): -- Gitee From 8bff9a10357500a28486d4c4c0d0c71dea727817 Mon Sep 17 00:00:00 2001 From: Nickyi Date: Tue, 11 Nov 2025 11:19:02 +0800 Subject: [PATCH 3/6] update --- .vscode/settings.json | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 24e2c93c6..000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "workbench.editor.wrapTabs": true -} \ No newline at end of file -- Gitee From 37d0a1c992ceaa258329f31e47a58f80bdfab11d Mon Sep 17 00:00:00 2001 From: Nickyi Date: Tue, 11 Nov 2025 16:45:57 +0800 Subject: [PATCH 4/6] update --- omni/layers/attention/backend/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omni/layers/attention/backend/attention.py b/omni/layers/attention/backend/attention.py index c5d5f888c..7db96c8ee 100644 --- a/omni/layers/attention/backend/attention.py +++ b/omni/layers/attention/backend/attention.py @@ -655,7 +655,7 @@ class AscendAttentionBackendImpl(AttentionImpl): device=query.device) if attn_metadata is None: - return output.view(num_tokens, self.hidden_size) + return output.view(num_tokens, -1) if not (layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0): raise RuntimeError("layer._k_scale_float and layer._v_scale_float must both be 1.0") -- Gitee From 717cd42e3836084a84f2dd3bda30dede639cd3ac Mon Sep 17 00:00:00 2001 From: Nickyi Date: Mon, 17 Nov 2025 17:41:13 +0800 Subject: [PATCH 5/6] add 
chunked_prefill --- omni/adaptors/vllm/worker/npu_model_runner.py | 15 ++-- omni/layers/attention/backend/attention.py | 71 ++++++++++++++----- ...t_start_api_servers_pangu_72Bv2_chunked.sh | 41 +++++++++++ 3 files changed, 105 insertions(+), 22 deletions(-) create mode 100644 tools/scripts/test_start_api_servers_pangu_72Bv2_chunked.sh diff --git a/omni/adaptors/vllm/worker/npu_model_runner.py b/omni/adaptors/vllm/worker/npu_model_runner.py index e25de1ec1..21e5a7a06 100644 --- a/omni/adaptors/vllm/worker/npu_model_runner.py +++ b/omni/adaptors/vllm/worker/npu_model_runner.py @@ -1268,14 +1268,19 @@ class NPUModelRunner(GPUModelRunner): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - kv_caches[layer_name] = self.attn_backends[i].init_kv_cache_each_layer(kv_cache_shape, self.dtype, - self.device, - self.model_config, - self.enable_torchair_graph_mode) + hf_config = self.vllm_config.model_config.hf_config v_channels = getattr(hf_config, "v_channels", None) if v_channels is not None: - kv_caches[layer_name] = (kv_caches[layer_name][0], kv_caches[layer_name][1][...,:v_channels].contiguous()) + kv_caches[layer_name] = self.attn_backends[i].init_kv_cache_each_layer_sink_attetion(kv_cache_shape, self.dtype, + self.device, + self.model_config, + self.enable_torchair_graph_mode,v_channels) + else: + kv_caches[layer_name] = self.attn_backends[i].init_kv_cache_each_layer(kv_cache_shape, self.dtype, + self.device, + self.model_config, + self.enable_torchair_graph_mode) if preemption_mode and preemption_mode == "swap": cpu_num_blocks = int(self.vllm_config.cache_config.swap_space_bytes // diff --git a/omni/layers/attention/backend/attention.py b/omni/layers/attention/backend/attention.py index 7db96c8ee..f1e3f067b 100644 --- a/omni/layers/attention/backend/attention.py +++ b/omni/layers/attention/backend/attention.py @@ -1097,27 +1097,43 @@ class 
AscendAttentionBackendImpl(AttentionImpl): else: # use chunked prefill for head size 192 scenario, like deepseek # paged_attention_splitfuse maybe crash at such scenario - all_key = self.key_cache.view(-1, self.num_kv_heads, self.head_size)[attn_metadata.kv_index].contiguous() - all_value = self.value_cache.view(-1, self.num_kv_heads, self.head_size)[attn_metadata.kv_index].contiguous() - actual_seq_qlen = np.array(attn_metadata.query_lens).cumsum().tolist() - actual_seq_kvlen = np.array(attn_metadata.seq_lens).cumsum().tolist() - attn_output = torch_npu.npu_fusion_attention( - query, - all_key, - all_value, - head_num=self.num_heads, + + block_size = self.value_cache.shape[1] + num_batch = attn_metadata.query_lens.shape[0] + # sink stored in block 0 + block_tables = F.pad(attn_metadata.block_tables, (1, 0, 0, 0), value=0) + # PA模式actual_seq_lengths累加,actual_seq_lengths_kv不累加;非PA模式都是累加 + cu_seqlen_q = [0] + attn_metadata.query_lens.tolist() + cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu") + cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0) + if sink_key is not None: + seq_lens_with_sink = attn_metadata.seq_lens + sink_key.shape[0] + cu_seqlen = [0] + (attn_metadata.seq_lens.tolist() if sink_key is None else \ + seq_lens_with_sink.tolist()) + cu_seqlen = torch.tensor(cu_seqlen, device="npu") + cu_seqlen = torch.cumsum(cu_seqlen, dim=0)[1:] + atten_mask = ~torch.tril( + torch.ones((2048, 2048), device='npu', dtype=torch.bool) + ) + attn_output = torch_npu.npu_fused_infer_attention_score( + query[:cu_seqlen_q[-1],:,:], + self.key_cache.view(-1, block_size, self.num_kv_heads * self.head_size), + self.value_cache.view(-1, block_size, self.num_kv_heads * v_head_size), + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, input_layout="TND", scale=self.scale, - atten_mask=AscendAttentionBackendImpl.SHARE_MASK_TRIL_SPARSE, sparse_mode=3, - actual_seq_qlen=actual_seq_qlen, - actual_seq_kvlen=actual_seq_kvlen, + pre_tokens=2147483647, + next_tokens=0, + 
atten_mask=atten_mask, + inner_precise=0, + block_table=block_tables[:num_batch], + block_size=block_size, + actual_seq_lengths=cu_seqlen_q[1:], + actual_seq_lengths_kv=attn_metadata.seq_lens + sink_key.shape[0], )[0] - - output = output.view_as(attn_output) - output.copy_(attn_output) - - return output.view(num_tokens, -1) + output[:cu_seqlen_q[-1], :].copy_(attn_output) class AscendAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @@ -1195,3 +1211,24 @@ class AscendAttentionBackend(AttentionBackend): if not int(os.getenv("NO_NPU_MOCK", "0")) and device != "cpu": torch_npu.npu_format_cast(layer_kv_caches, 2) return (layer_kv_caches[0], layer_kv_caches[1]) + + @staticmethod + def init_kv_cache_each_layer_sink_attetion(kv_cache_shape, dtype, device, model_config: "ModelConfig", enable_graph_mode, v_channel) -> \ + tuple[torch.Tensor, ...]: + # KVCache needs to store the shape of the reduced dimension [num_blocks, block_size, 1, kv_lora_rank] [num_blocks, block_size, 1, rope_dim] + # The shape of the augmented dimension is [num_blocks, block_size, head_num, head_dim] + k_cache_shape = kv_cache_shape[1:] + v_cache_shape = kv_cache_shape[1:] + *rest, last = v_cache_shape + v_cache_shape = (*rest, v_channel) + + layer_k_cache = torch.zeros(k_cache_shape, + dtype=dtype if not model_extra_config.operator_opt_config.fa_quant else torch.int8, + device=device) + layer_v_cache = torch.zeros(v_cache_shape, + dtype=dtype if not model_extra_config.operator_opt_config.fa_quant else torch.int8, + device=device) + if not int(os.getenv("NO_NPU_MOCK", "0")) and device != "cpu": + torch_npu.npu_format_cast(layer_k_cache, 2) + torch_npu.npu_format_cast(layer_v_cache, 2) + return (layer_k_cache, layer_v_cache) diff --git a/tools/scripts/test_start_api_servers_pangu_72Bv2_chunked.sh b/tools/scripts/test_start_api_servers_pangu_72Bv2_chunked.sh new file mode 100644 index 000000000..fcda14c77 --- /dev/null +++ 
b/tools/scripts/test_start_api_servers_pangu_72Bv2_chunked.sh @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +export GLOO_SOCKET_IFNAME=enp23s0f3 +export TP_SOCKET_IFNAME=enp23s0f3 +# enp67s0f5 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 +export VLLM_USE_V1=1 +export VLLM_WORKER_MULTIPROC_METHOD=fork +export VLLM_ENABLE_MC2=0 +export USING_LCCL_COM=0 +export ASCEND_LAUNCH_BLOCKING=1 +export ENABLE_PREFILL_TND=1 +export TORCHDYNAMO_VERBOSE=0 +export OMNI_USE_PANGU=1 +export PYTHONUNBUFFERED=1 +# export TORCH_LOGS="+dynamo" +# export TORCH_LOGS="" +# unset TORCH_LOGS +# export OMNI_USE_QWEN=1 +# 混部模式开启enable_hybrid_graph_mode + # --additional-config '{ "enable_hybrid_graph_mode": true}' \ +# --max-num-batched-tokens 512 +python start_api_servers.py \ + --num-servers 1 \ + --model-path /mnt/sfs_turbo/bucket-910c-6055/rrj/model/MoE_74BA15B/vllm_ascend_mtp/eighth_2550b/fix/iter_0009500/ \ + --master-ip 8.8.8.8 \ + --tp 4 \ + --master-port 35678 \ + --served-model-name PanguProMoE \ + --log-dir apiserverlog \ + --extra-args "--enable-expert-parallel --long-prefill-token-threshold 1024 " \ + --gpu-util 0.8 \ + --base-api-port 9555 \ + --max-model-len 4096 \ + --no-enable-prefix-caching \ + --additional-config '{"graph_model_compile_config":{"level":1, "use_ge_graph_cached":false}, "enable_hybrid_graph_mode": false, "expert_parallel_size": 4, "expert_tensor_parallel_size": 1}' + + +# EXTRA_ARGS里 把no enable chunked prefill删了,通过控制--max-num-batched-tokens和--long-prefill-token-threshold 实现chunked prefill +# EXTRA_ARGS='--max-num-batched-tokens 32768 --enforce-eager --no-enable-prefix-caching --enable-expert-parallel --disable-log-requests --max-num-seqs 16 --long-prefill-token-threshold 32768' + -- Gitee From 96aecbd9188930fc0f3ec4f1af123aacdcead2d1 Mon Sep 17 00:00:00 2001 From: Nickyi Date: Tue, 18 Nov 2025 10:57:45 +0800 Subject: [PATCH 6/6] update --- omni/layers/attention/backend/attention.py | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/omni/layers/attention/backend/attention.py b/omni/layers/attention/backend/attention.py index f1e3f067b..6656fc195 100644 --- a/omni/layers/attention/backend/attention.py +++ b/omni/layers/attention/backend/attention.py @@ -1026,7 +1026,7 @@ class AscendAttentionBackendImpl(AttentionImpl): cu_seqlen = torch.tensor(cu_seqlen, device="npu") cu_seqlen = torch.cumsum(cu_seqlen, dim=0)[1:] attn_output = torch_npu.npu_fused_infer_attention_score( - query, + query[:cu_seqlen_q[-1],:,:], key, value, num_heads=self.num_heads, @@ -1040,7 +1040,7 @@ class AscendAttentionBackendImpl(AttentionImpl): inner_precise=0, actual_seq_lengths=cu_seqlen_q[1:], actual_seq_lengths_kv=cu_seqlen)[0] - output.copy_(attn_output) + output[:cu_seqlen_q[-1], :].copy_(attn_output) output = output.view(-1, self.num_heads * v_head_size) if stream_for_reshape_and_cache != torch.npu.current_stream(): -- Gitee