From e4cb111357a81591a6b161fe87f7bb54c12d2819 Mon Sep 17 00:00:00 2001 From: imlixy Date: Thu, 20 Nov 2025 14:50:54 +0800 Subject: [PATCH 1/4] merge gittee master 1120 and qwen3_moe c8 supported --- .../models/configs/best_practice_configs.json | 10 ++ .../qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json | 22 ++++ .../qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json | 19 +++ omni/models/qwen/qwen3_moe.py | 38 ++++-- .../npu_w8a8_dynamic/npu_w8a8_dynamic.py | 117 +++++++++++++++++- 5 files changed, 194 insertions(+), 12 deletions(-) create mode 100644 omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json create mode 100644 omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json diff --git a/omni/models/configs/best_practice_configs.json b/omni/models/configs/best_practice_configs.json index a11d559118..09dbe51fe4 100644 --- a/omni/models/configs/best_practice_configs.json +++ b/omni/models/configs/best_practice_configs.json @@ -369,6 +369,16 @@ "prefill_config_file": "qwen3_235b_w8a8c16_a3_2p1d_p.json", "decode_config_file": "qwen3_235b_w8a8c16_a3_2p1d_d.json" }, + { + "model": "qwen3_30b_a3b", + "hardware": "A3", + "precision": "w8a8", + "prefill_node_num": 1, + "decode_node_num": 1, + "pd_disaggregation": true, + "prefill_config_file": "qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json", + "decode_config_file": "qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json" + }, { "model": "qwen3_30b_a3b", "hardware": "A3", diff --git a/omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json b/omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json new file mode 100644 index 0000000000..95dba82de4 --- /dev/null +++ b/omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json @@ -0,0 +1,22 @@ +{ + "model_parallel_config": { + "dense_mlp_tp_size": 4, + "o_proj_tp_size": 4 + }, + "operator_optimizition_config": { + "enable_kv_rmsnorm_rope_cache": true, + "best_ep": false, + "merge_qkv": false, + "gmm_nz": true, + "unquant_bmm_nz": true, + "decode_moe_dispatch_combine": true, + "use_super_kernel": true, + "enable_c8": true, + "use_mlaprolog": false, + "control_accept_rate": -1, + "use_prefetch": true, + "expert_gate_up_prefetch": 16, + "expert_down_prefetch": 16, + "attn_prefetch": 96 + } +} \ No newline at end of file diff --git a/omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json b/omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json new file mode 100644 index 0000000000..2c9296a66e --- /dev/null +++ b/omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json @@ -0,0 +1,19 @@ +{ + "model_parallel_config": { + "dense_mlp_tp_size": 16, + "o_proj_tp_size": 16 + }, + "operator_optimizition_config": { + "enable_kv_rmsnorm_rope_cache": true, + "prefill_moe_all_to_all": true, + "best_ep": false, + "merge_qkv": false, + "gmm_nz": true, + "unquant_bmm_nz": true, + "enable_c8": true, + "control_accept_rate": -1, + "enable_prefill_micro_batch": false, + "experts_pruning": false, + "enable_mlp_seq_split": false + } +} \ No newline at end of file diff --git a/omni/models/qwen/qwen3_moe.py b/omni/models/qwen/qwen3_moe.py index 02fec18352..50d22dec83 100644 --- a/omni/models/qwen/qwen3_moe.py +++ b/omni/models/qwen/qwen3_moe.py @@ -64,6 +64,7 @@ from omni.layers.linear import (RowParallelFlashCommLinear, QKVParallelFlashCommLinear) from omni.layers.rotary_embedding import get_rope from omni.layers.attention.backend.attention import AscendAttentionState +from omni.layers.attention.layer import 
attention_init_c8 from omni.layers.utils import ConditionalTNGScope from omni.models.config_loader.loader import model_extra_config @@ -212,14 +213,29 @@ class Qwen3MoeAttention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + + if model_extra_config.operator_opt_config.enable_c8: + Attention.__init__ = attention_init_c8 + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + total_num_kv_heads=self.total_num_kv_heads + ) + else: + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn" + ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -393,8 +409,8 @@ class Qwen3MoeModel(nn.Module): config = config cache_config = cache_config - quant_config = quant_config + self.quant_config = quant_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config @@ -546,7 +562,9 @@ class Qwen3MoeModel(nn.Module): expert_id=expert_id) break else: - + if ("k_scale" in name) or ("v_scale" in name): + name = self.quant_config.get_cache_scale(name) + loaded_weight = loaded_weight.view(-1) if is_pp_missing_parameter(name, self): continue if name not in params_dict: diff --git a/omni/quantization/npu_w8a8_dynamic/npu_w8a8_dynamic.py b/omni/quantization/npu_w8a8_dynamic/npu_w8a8_dynamic.py index b19c4102dc..340e03c334 100644 --- a/omni/quantization/npu_w8a8_dynamic/npu_w8a8_dynamic.py +++ b/omni/quantization/npu_w8a8_dynamic/npu_w8a8_dynamic.py @@ -31,12 +31,14 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped ) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.parameter import ( ModelWeightParameter, ChannelQuantScaleParameter ) from vllm.model_executor.utils import set_weight_attrs from vllm.distributed import get_tp_group, get_dp_group, get_ep_group +from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank from omni.layers.fused_mlp import FusedMLP, FusedMLPMethodBase, W8A8DynamicFusedMLPMethod from omni.layers.linear import ( @@ -53,6 +55,7 @@ from omni.adaptors.vllm.distributed.parallel_state import( from omni.models.config_loader.loader import model_extra_config +SUPPORTED_KV_QUANT_STRATEGY = ["channel"] logger = init_logger(__name__) @@ -63,10 +66,12 @@ class NpuW8A8DynamicConfig(QuantizationConfig): def __init__( self, ignored_layers: Optional[List[str]] = None, - ignore: Optional[List[str]] = None + ignore: Optional[List[str]] = None, + kv_cache_scheme: Optional[Dict[str, Any]] = None ) -> None: self.ignored_layers = ignored_layers or [] self.ignore = ignore + self.kv_cache_scheme = kv_cache_scheme @classmethod def get_name(cls) -> str: @@ -97,10 +102,12 @@ class NpuW8A8DynamicConfig(QuantizationConfig): quant_method = cls.get_from_keys(config, ['quant_method']) ignored_layers = cls.get_from_keys_or(config, ['ignored_layers'], None) ignore = cls.get_from_keys_or(config, ['ignore'], []) - return 
cls(ignored_layers=ignored_layers, ignore=ignore) + kv_cache_scheme = cls.get_from_keys_or(config, ['kv_cache_scheme'], None) + return cls(ignored_layers=ignored_layers, ignore=ignore, kv_cache_scheme=kv_cache_scheme) def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention if isinstance(layer, FlashCommLinearBase) or isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): return UnquantizedFlashCommLinearMethod() @@ -109,12 +116,27 @@ class NpuW8A8DynamicConfig(QuantizationConfig): return W8A8DynamicFusedMLPMethod(self) elif isinstance(layer, FusedMoE): return NpuW8A8DynamicFusedMoEMethod(self) + elif isinstance(layer, Attention) and model_extra_config.operator_opt_config.enable_c8: + return NpuW8A8DynamicKVCacheMethod(self) return None def get_scaled_act_names(self) -> List[str]: return [] def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in npu_w8a8_dynamic. If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".k_scale"): + return name.replace(".k_scale", ".attn.k_scale") + if name.endswith(".v_scale"): + return name.replace(".v_scale", ".attn.v_scale") + # If no matches, return None return None @@ -730,3 +752,94 @@ class NpuW8A8DynamicFusedMoEMethod(FusedMoEMethodBase): y = get_ep_group().reduce_scatter(y) return y + +class NpuW8A8DynamicKVCacheMethod(BaseKVCacheMethod): + """Supports loading kv-cache scaling factors from npu_w8a8_dynamic checkpoints. + KVCache method for NPU W8A8 Dynamic. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: NpuW8A8DynamicConfig): + self.quant_config = quant_config + self.validate_kv_cache_scheme(quant_config.kv_cache_scheme) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]): + """ + Validator for the kv cache scheme. Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_scheme: the npu_w8a8_dynamic kv cache scheme + """ + + if kv_cache_scheme is None: + raise ValueError("When enable kv cache quantization, " + "kv_cache_scheme must not be null in config.json") + + type_ = kv_cache_scheme.get("type") + num_bits = kv_cache_scheme.get("num_bits") + + if type_ != "int" and num_bits != 8: + raise NotImplementedError( + "Currently supported kv cache quantization is " + "num_bits=8, type=int, however " + f"received num_bits={num_bits}, type={type_}") + + strategy = kv_cache_scheme.get("strategy") + if strategy not in SUPPORTED_KV_QUANT_STRATEGY: + raise NotImplementedError( + f"Only support {SUPPORTED_KV_QUANT_STRATEGY} scaling factor " + f"for npu_w8a8_dynamic KV cache, found strategy: {strategy}") + + is_symmetric = kv_cache_scheme.get("symmetric") + if not is_symmetric: + raise NotImplementedError( + "Only support symmetric scaling factor " + "for npu_w8a8_dynamic KV cache. " + f"However found symmetric: {is_symmetric}") + + def create_weights(self, + layer: torch.nn.Module, + total_num_kv_heads: int, + head_size: int + ): + """ + Create "k_scale" and "v_scale" + for an attention layer. 
+ """ + if self.quant_config.kv_cache_scheme is not None: + strategy = self.quant_config.kv_cache_scheme.get("strategy") + if strategy in SUPPORTED_KV_QUANT_STRATEGY: + self.total_num_kv_heads = total_num_kv_heads + scale_num = total_num_kv_heads * head_size + layer.k_scale = torch.nn.Parameter(torch.ones(scale_num, dtype=torch.get_default_dtype(), device='npu'), + requires_grad=False) + layer.v_scale = torch.nn.Parameter(torch.ones(scale_num, dtype=torch.get_default_dtype(), device='npu'), + requires_grad=False) + + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if self.quant_config.kv_cache_scheme is not None: + scale_num = layer.k_scale.shape[0] + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + slice_size = scale_num // min(self.total_num_kv_heads, tp_size) + num_kv_head_replicas = tp_size // self.total_num_kv_heads if tp_size >= self.total_num_kv_heads else 1 + local_index = tp_rank // num_kv_head_replicas + slice_start = local_index * slice_size + slice_end = slice_start + slice_size + + strategy = self.quant_config.kv_cache_scheme.get("strategy") + if strategy in SUPPORTED_KV_QUANT_STRATEGY: + layer.k_scale = torch.nn.Parameter(layer.k_scale[slice_start:slice_end].view(1, -1), + requires_grad=False) + layer.v_scale = torch.nn.Parameter(layer.v_scale[slice_start:slice_end].view(1, -1), + requires_grad=False) + + layer.k_scale_reciprocal = torch.nn.Parameter(1/layer.k_scale.to(torch.float32), + requires_grad=False) + layer.v_scale_reciprocal = torch.nn.Parameter(1/layer.v_scale.to(torch.float32), + requires_grad=False) \ No newline at end of file -- Gitee From 9d5d41a11ab5abdca23db20f32f37180afc50188 Mon Sep 17 00:00:00 2001 From: imlixy Date: Thu, 20 Nov 2025 15:17:14 +0800 Subject: [PATCH 2/4] merge gittee master 1120 and qwen3_moe c8 supported --- .../models/configs => }/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json | 0 .../models/configs => }/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json | 1 + 2 files changed, 1 insertion(+) rename omni/models/configs/{omni/models/configs => }/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json (100%) rename omni/models/configs/{omni/models/configs => }/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json (97%) diff --git a/omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json b/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json similarity index 100% rename from omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json rename to omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json diff --git a/omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json b/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json similarity index 97% rename from omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json rename to omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json index 2c9296a66e..e83fce8636 100644 --- a/omni/models/configs/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json +++ b/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json @@ -11,6 +11,7 @@ "gmm_nz": true, "unquant_bmm_nz": true, "enable_c8": true, + "" "control_accept_rate": -1, "enable_prefill_micro_batch": false, "experts_pruning": false, -- Gitee From d1945ae88c76d177d5b442c28217e07472eedf64 Mon Sep 17 00:00:00 2001 From: imlixy Date: Thu, 20 Nov 2025 20:09:34 +0800 Subject: [PATCH 3/4] merge gittee master 1120 and qwen3_moe c8 supported --- omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json | 1 + omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json | 2 +- 2 
files changed, 2 insertions(+), 1 deletion(-) diff --git a/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json b/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json index 95dba82de4..5a5877d68f 100644 --- a/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json +++ b/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json @@ -12,6 +12,7 @@ "decode_moe_dispatch_combine": true, "use_super_kernel": true, "enable_c8": true, + "use_tnd_pa":true, "use_mlaprolog": false, "control_accept_rate": -1, "use_prefetch": true, diff --git a/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json b/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json index e83fce8636..cf4c5cf103 100644 --- a/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json +++ b/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_p.json @@ -11,7 +11,7 @@ "gmm_nz": true, "unquant_bmm_nz": true, "enable_c8": true, - "" + "use_tnd_pa":true, "control_accept_rate": -1, "enable_prefill_micro_batch": false, "experts_pruning": false, -- Gitee From d647fa67399bf323f0d755020f39c0fbe38fdd6a Mon Sep 17 00:00:00 2001 From: imlixy Date: Fri, 21 Nov 2025 15:39:16 +0800 Subject: [PATCH 4/4] support c8 for both name k_scale and kv_cache_scale --- omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json | 2 +- omni/models/qwen/qwen3_moe.py | 8 +++++++- omni/quantization/npu_w8a8_dynamic/npu_w8a8_dynamic.py | 4 ++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json b/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json index 5a5877d68f..402ebf429f 100644 --- a/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json +++ b/omni/models/configs/qwen3_30b_a3b_w8a8c8_a3_1p1d_d.json @@ -9,7 +9,7 @@ "merge_qkv": false, "gmm_nz": true, "unquant_bmm_nz": true, - "decode_moe_dispatch_combine": true, + "decode_moe_dispatch_combine": false, "use_super_kernel": true, "enable_c8": true, "use_tnd_pa":true, diff --git a/omni/models/qwen/qwen3_moe.py b/omni/models/qwen/qwen3_moe.py index 50d22dec83..c2453d5f61 100644 --- a/omni/models/qwen/qwen3_moe.py +++ b/omni/models/qwen/qwen3_moe.py @@ -510,6 +510,12 @@ class Qwen3MoeModel(nn.Module): ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts) + kv_scale_mapping = { + "k_scale", + "v_scale", + "kv_cache_scale", + } + params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -562,7 +568,7 @@ class Qwen3MoeModel(nn.Module): expert_id=expert_id) break else: - if ("k_scale" in name) or ("v_scale" in name): + if any(key in name for key in kv_scale_mapping): name = self.quant_config.get_cache_scale(name) loaded_weight = loaded_weight.view(-1) if is_pp_missing_parameter(name, self): diff --git a/omni/quantization/npu_w8a8_dynamic/npu_w8a8_dynamic.py b/omni/quantization/npu_w8a8_dynamic/npu_w8a8_dynamic.py index 340e03c334..f45ff821b6 100644 --- a/omni/quantization/npu_w8a8_dynamic/npu_w8a8_dynamic.py +++ b/omni/quantization/npu_w8a8_dynamic/npu_w8a8_dynamic.py @@ -136,6 +136,10 @@ class NpuW8A8DynamicConfig(QuantizationConfig): return name.replace(".k_scale", ".attn.k_scale") if name.endswith(".v_scale"): return name.replace(".v_scale", ".attn.v_scale") + if name.endswith(".kv_cache_scale") and ".k_proj" in name: + return name.replace(".k_proj.kv_cache_scale", ".attn.k_scale") + if name.endswith(".kv_cache_scale") and ".v_proj" in name: + return name.replace(".v_proj.kv_cache_scale", ".attn.v_scale") # If no matches, return None return None -- Gitee
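
Usage sketch (illustrative, not taken from the patches themselves): with the c8 path added above, NpuW8A8DynamicConfig.from_config() picks up an optional "kv_cache_scheme" block from the checkpoint's quantization config, and validate_kv_cache_scheme() only accepts int8, per-channel, symmetric scales. Assuming the usual quantization config layout consumed by vLLM, a minimal fragment that would pass the validator could look like:

    "kv_cache_scheme": {
        "type": "int",
        "num_bits": 8,
        "strategy": "channel",
        "symmetric": true
    }

Scale tensors in the checkpoint are then renamed by get_cache_scale() before loading: a hypothetical parameter name such as "model.layers.0.self_attn.k_proj.kv_cache_scale" (or the shorter "model.layers.0.self_attn.k_scale") maps to "model.layers.0.self_attn.attn.k_scale", is flattened to total_num_kv_heads * head_size per-channel entries by the loader in qwen3_moe.py, and is sliced per tensor-parallel rank in process_weights_after_loading(). The layer prefix here is only an example; actual names follow the checkpoint being loaded.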