From 0328ba6db42a828b60652f48b600beee36ebae99 Mon Sep 17 00:00:00 2001 From: Zhi Bowen Date: Wed, 26 Nov 2025 10:00:09 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=9B=B4=E6=96=B0Qwen3-Next=E7=9A=84BF16?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=EF=BC=8C=E6=96=B0=E5=A2=9EQwen3-Next?= =?UTF-8?q?=E7=9A=84w8a8=E9=87=8F=E5=8C=96=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../models/configs/best_practice_configs.json | 20 +++++++++++++++++++ .../configs/qwen3_next_bf16_a2_1p1d_d.json | 16 +++++---------- .../configs/qwen3_next_bf16_a2_1p1d_p.json | 11 ++++------ .../configs/qwen3_next_w8a8_a2_1p1d_d.json | 17 ++++++++++++++++ .../configs/qwen3_next_w8a8_a2_1p1d_p.json | 14 +++++++++++++ .../configs/qwen3_next_w8a8_a2_2p1d_d.json | 19 ++++++++++++++++++ .../configs/qwen3_next_w8a8_a2_2p1d_p.json | 14 +++++++++++++ 7 files changed, 93 insertions(+), 18 deletions(-) create mode 100644 omni/models/configs/qwen3_next_w8a8_a2_1p1d_d.json create mode 100644 omni/models/configs/qwen3_next_w8a8_a2_1p1d_p.json create mode 100644 omni/models/configs/qwen3_next_w8a8_a2_2p1d_d.json create mode 100644 omni/models/configs/qwen3_next_w8a8_a2_2p1d_p.json diff --git a/omni/models/configs/best_practice_configs.json b/omni/models/configs/best_practice_configs.json index fb05038fb..1acf3601e 100644 --- a/omni/models/configs/best_practice_configs.json +++ b/omni/models/configs/best_practice_configs.json @@ -317,6 +317,26 @@ "prefill_config_file": "qwen3_next_bf16_a2_1p1d_p.json", "decode_config_file": "qwen3_next_bf16_a2_1p1d_d.json" }, + { + "model": "qwen3-next", + "hardware": "A2", + "precision": "w8a8", + "prefill_node_num": 1, + "decode_node_num": 1, + "pd_disaggregation": true, + "prefill_config_file": "qwen3_next_w8a8_a2_1p1d_p.json", + "decode_config_file": "qwen3_next_w8a8_a2_1p1d_d.json" + }, + { + "model": "qwen3-next", + "hardware": "A2", + "precision": "w8a8", + "prefill_node_num": 2, + "decode_node_num": 1, + "pd_disaggregation": true, + "prefill_config_file": "qwen3_next_w8a8_a2_2p1d_p.json", + "decode_config_file": "qwen3_next_w8a8_a2_2p1d_d.json" + }, { "model": "qwen3-next", "hardware": "A3", diff --git a/omni/models/configs/qwen3_next_bf16_a2_1p1d_d.json b/omni/models/configs/qwen3_next_bf16_a2_1p1d_d.json index 94085babb..96acc9c54 100644 --- a/omni/models/configs/qwen3_next_bf16_a2_1p1d_d.json +++ b/omni/models/configs/qwen3_next_bf16_a2_1p1d_d.json @@ -1,24 +1,18 @@ { "model_parallel_config": { "dense_mlp_tp_size": 4, - "o_proj_tp_size": 1, - "dp_size": 1 + "o_proj_tp_size": 1 }, "operator_optimizition_config": { "enable_kv_rmsnorm_rope_cache": true, - "prefill_moe_all_to_all": false, "moe_multi_stream_tune": true, - "best_ep": false, - "merge_qkv": false, "two_stage_comm": true, "gmm_nz": true, + "unquant_bmm_nz": true, "decode_moe_dispatch_combine": false, - "opt_w2_scale_cast": false, - "enable_round_pipeline_comm": false, - "enable_pipeline_comm": false, - "pd_seperate_prefill": false, - "prefill_enable_long_seq": false, + "use_super_kernel":true, "use_prefetch": true, - "prefill_enable_mla_alltoall_local": false + "expert_gate_up_prefetch": 30, + "expert_down_prefetch": 0 } } diff --git a/omni/models/configs/qwen3_next_bf16_a2_1p1d_p.json b/omni/models/configs/qwen3_next_bf16_a2_1p1d_p.json index 3ad28eecb..8077a8867 100644 --- a/omni/models/configs/qwen3_next_bf16_a2_1p1d_p.json +++ b/omni/models/configs/qwen3_next_bf16_a2_1p1d_p.json @@ -1,17 +1,14 @@ { "model_parallel_config": { "dense_mlp_tp_size": 4, - "o_proj_tp_size": 1, - "dp_size": 1 + "o_proj_tp_size": 1 }, "operator_optimizition_config": { "enable_kv_rmsnorm_rope_cache": true, "prefill_moe_all_to_all": true, - "best_ep": false, - "merge_qkv": false, + "two_stage_comm":true, "gmm_nz": true, - "control_accept_rate": -1, - "enable_prefill_micro_batch": false, - "experts_pruning": false + "unquant_bmm_nz": true, + "prefill_enable_mla_alltoall_local": false } } \ No newline at end of file diff --git a/omni/models/configs/qwen3_next_w8a8_a2_1p1d_d.json b/omni/models/configs/qwen3_next_w8a8_a2_1p1d_d.json new file mode 100644 index 000000000..cf8beb265 --- /dev/null +++ b/omni/models/configs/qwen3_next_w8a8_a2_1p1d_d.json @@ -0,0 +1,17 @@ +{ + "model_parallel_config": { + "dense_mlp_tp_size": 4, + "o_proj_tp_size": 1 + }, + "operator_optimizition_config": { + "enable_kv_rmsnorm_rope_cache": true, + "moe_multi_stream_tune": true, + "two_stage_comm": true, + "gmm_nz": true, + "decode_moe_dispatch_combine": false, + "use_super_kernel":true, + "use_prefetch": true, + "expert_gate_up_prefetch": 50, + "expert_down_prefetch": 0 + } +} diff --git a/omni/models/configs/qwen3_next_w8a8_a2_1p1d_p.json b/omni/models/configs/qwen3_next_w8a8_a2_1p1d_p.json new file mode 100644 index 000000000..6fc1e0edf --- /dev/null +++ b/omni/models/configs/qwen3_next_w8a8_a2_1p1d_p.json @@ -0,0 +1,14 @@ +{ + "model_parallel_config": { + "dense_mlp_tp_size": 4, + "o_proj_tp_size": 1 + }, + "operator_optimizition_config": { + "enable_kv_rmsnorm_rope_cache": true, + "prefill_moe_all_to_all": true, + "two_stage_comm":true, + "gmm_nz": true, + "pd_seperate_prefill": true, + "prefill_enable_mla_alltoall_local": false + } +} \ No newline at end of file diff --git a/omni/models/configs/qwen3_next_w8a8_a2_2p1d_d.json b/omni/models/configs/qwen3_next_w8a8_a2_2p1d_d.json new file mode 100644 index 000000000..efb9923e7 --- /dev/null +++ b/omni/models/configs/qwen3_next_w8a8_a2_2p1d_d.json @@ -0,0 +1,19 @@ +{ + "model_parallel_config": { + "dense_mlp_tp_size": 4, + "o_proj_tp_size": 1 + }, + "operator_optimizition_config": { + "enable_kv_rmsnorm_rope_cache": true, + "moe_multi_stream_tune": true, + "two_stage_comm": true, + "gmm_nz": true, + "decode_moe_dispatch_combine": true, + "enable_round_pipeline_comm": false, + "enable_pipeline_comm": true, + "use_super_kernel": true, + "use_prefetch": true, + "expert_gate_up_prefetch": 50, + "expert_down_prefetch": 0 + } +} diff --git a/omni/models/configs/qwen3_next_w8a8_a2_2p1d_p.json b/omni/models/configs/qwen3_next_w8a8_a2_2p1d_p.json new file mode 100644 index 000000000..c2bb8b1d4 --- /dev/null +++ b/omni/models/configs/qwen3_next_w8a8_a2_2p1d_p.json @@ -0,0 +1,14 @@ +{ + "model_parallel_config": { + "dense_mlp_tp_size": 4, + "o_proj_tp_size": 1 + }, + "operator_optimizition_config": { + "enable_kv_rmsnorm_rope_cache": true, + "prefill_moe_all_to_all": true, + "two_stage_comm": true, + "gmm_nz": true, + "pd_seperate_prefill": true, + "prefill_enable_mla_alltoall_local": false + } +} \ No newline at end of file -- Gitee From 5625c057e1aef42b36f5072b4f9a518c01a5c41f Mon Sep 17 00:00:00 2001 From: Yu Gao Date: Thu, 27 Nov 2025 20:17:19 +0800 Subject: [PATCH 2/2] use custom ops and fix a bug --- omni/layers/attention/backend/gdn_attn.py | 4 ++ .../layers/attention/linear/sigmoid_gating.py | 2 +- omni/models/qwen/qwen3_next.py | 45 +++++++------------ 3 files changed, 21 insertions(+), 30 deletions(-) diff --git a/omni/layers/attention/backend/gdn_attn.py b/omni/layers/attention/backend/gdn_attn.py index bb3d37ba1..2d26c9b36 100644 --- a/omni/layers/attention/backend/gdn_attn.py +++ b/omni/layers/attention/backend/gdn_attn.py @@ -177,6 +177,10 @@ class GDNAttentionMetadataBuilder(AscendAttentionMetadataBuilder): spec_token_masks = None spec_state_indices_tensor = None non_spec_state_indices_tensor = self.block_table.block_table[:, 0] + if num_decodes == 0: + non_spec_state_indices_tensor[num_prefills:] = 0 + if num_prefills == 0: + non_spec_state_indices_tensor[num_decodes:] = 0 spec_query_start_loc = None non_spec_query_start_loc = query_start_loc num_accepted_tokens = None diff --git a/omni/layers/attention/linear/sigmoid_gating.py b/omni/layers/attention/linear/sigmoid_gating.py index fce3d52b0..9e72f086a 100644 --- a/omni/layers/attention/linear/sigmoid_gating.py +++ b/omni/layers/attention/linear/sigmoid_gating.py @@ -297,7 +297,7 @@ def _fused_recurrent_gated_delta_rule_npu( return o, initial_state # Placeholder for final_state def get_gdn_based_on_env(): - env_value = os.getenv("TORCH_NPU_USE_GDN_DECODE_OP", "False").strip().lower() + env_value = os.getenv("TORCH_NPU_USE_GDN_DECODE_OP", "true").strip().lower() if env_value in ("true", "1", "yes", "on"): return _fused_recurrent_gated_delta_rule_npu # Package is available else: diff --git a/omni/models/qwen/qwen3_next.py b/omni/models/qwen/qwen3_next.py index efb80bb56..89829425b 100644 --- a/omni/models/qwen/qwen3_next.py +++ b/omni/models/qwen/qwen3_next.py @@ -7,6 +7,8 @@ from math import log import torch import torch.nn.functional as F +import torch_npu +import omni_custom_ops from einops import rearrange from torch import nn from torch.library import Library @@ -246,25 +248,11 @@ def torch_chunk_gated_delta_rule( block_size = 1 attn_inv = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device).repeat((tuple(attn.shape)[:-2] + (1, 1))) attn = attn_inv - attn - for i in range(lg): - block_num = chunk_size // block_size - prod = attn @ attn_inv - attn_inv_block = attn_inv.view(tuple(attn.shape)[:-2] + (block_num, block_size, block_num, block_size)).transpose(-2, -3) - prod_block = prod.view(tuple(attn.shape)[:-2] + (block_num, block_size, block_num, block_size)).transpose(-2, -3) - r0 = torch.arange(block_num // 2, device=attn.device) * 2 - r1 = r0 + 1 - attn_inv_block[:, :, :, r1, r0, :, :] = -attn_inv_block[..., r1, r1, :, :] @ prod_block[..., r1, r0, :, :] - attn_inv = attn_inv_block.transpose(-2, -3).view(tuple(attn_inv_block.shape)[:-4] + (chunk_size, chunk_size)) - block_size *= 2 - attn = attn_inv + attn = torch_npu.npu_lower_triangular_inverse(attn) value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - last_recurrent_state = (torch.zeros(batch_size, num_v_heads, - k_head_dim, v_head_dim).to(value) if - initial_state is None else initial_state.to(value)) - core_attn_out = torch.zeros_like(value) mask = torch.triu(torch.ones(chunk_size, chunk_size, @@ -286,29 +274,28 @@ def torch_chunk_gated_delta_rule( v_new_out = torch.zeros_like(value) attn_inter_out = torch.zeros_like(value) - # for each chunk - for i in range(0, sequence_length_padded // chunk_size): - v_i = value[:, :, i] - attn = attn_score[:, :, i] - v_prime_attn_inter = (k_cumdecay_qgexp[:, :, i]) @ last_recurrent_state - v_prime = v_prime_attn_inter[:, :, :chunk_size] - attn_inter = v_prime_attn_inter[:, :, chunk_size:] - v_new = v_i - v_prime - v_new_out[:, :, i] = v_new - attn_inter_out[:, :, i] = attn_inter - last_recurrent_state *= gexp[:, :, i, -1, :, None] - last_recurrent_state += (kgexp[:, :, i]).transpose(-1, -2) @ v_new + initial_state_ = (torch.zeros(batch_size, num_v_heads, v_head_dim, k_head_dim).to(value) if initial_state is None else initial_state.to(value)) + actual_seqlens = torch.tensor([sequence_length_padded], device=qgexp.device, dtype=torch.int32) + attn_inter_out, v_new_out = torch.ops.custom.npu_chunk_gated_delta_rule_recurrence( + initial_state_.to(torch.float32), + kgexp.squeeze(0), + value.squeeze(0), + k_cumdecay.squeeze(0), + qgexp.squeeze(0), + gexp.squeeze(0), + actual_seqlens + ) core_attn_out = attn_inter_out + attn_score @ v_new_out if not output_final_state: - last_recurrent_state = None + initial_state_ = None core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) core_attn_out = core_attn_out[:, :, :sequence_length] core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) - return core_attn_out, last_recurrent_state.transpose(-1,-2) + return core_attn_out, initial_state_ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): -- Gitee