From 99505640a1d3414790f5639ff3cae24245d78cd4 Mon Sep 17 00:00:00 2001
From: lz9848 <2263941766@qq.com>
Date: Thu, 27 Nov 2025 16:36:14 +0800
Subject: [PATCH 1/2] pp support omni-attn & support a2 w4a8c8

---
 infer_engines/bash_install_code.sh            |   3 +-
 omni/accelerators/cache/pd.py                 | 132 ++++++++++++------
 omni/accelerators/pd/llmdatadist_manager.py   |  10 +-
 omni/adaptors/vllm/patches/pp_scheduler.patch |  67 +++++++++
 omni/layers/attention/deepseek_mla.py         |   3 +-
 omni/layers/moe/deepseek_moe.py               |   2 +-
 omni/models/deepseek/deepseek_v3_a2.py        |   4 +-
 omni/models/pangu/pangu_ultra_moe_a2.py       |   2 +-
 8 files changed, 168 insertions(+), 55 deletions(-)
 create mode 100644 omni/adaptors/vllm/patches/pp_scheduler.patch

diff --git a/infer_engines/bash_install_code.sh b/infer_engines/bash_install_code.sh
index f8a641adce..2e9c8e20e6 100644
--- a/infer_engines/bash_install_code.sh
+++ b/infer_engines/bash_install_code.sh
@@ -43,4 +43,5 @@ git apply --whitespace=nowarn $PATCH_ROOT/patch_reasoning_thinking_bug.patch
 git apply --whitespace=nowarn $PATCH_ROOT/tracing.patch
 git apply --whitespace=nowarn $PATCH_ROOT/support_v1_priority_schedule_for_xiaoyi.patch
 git apply --whitespace=nowarn $PATCH_ROOT/gpt_oss_model_init.patch
-git apply --whitespace=nowarn $PATCH_ROOT/openai_harmony_parser.patch
\ No newline at end of file
+git apply --whitespace=nowarn $PATCH_ROOT/openai_harmony_parser.patch
+git apply --whitespace=nowarn $PATCH_ROOT/pp_scheduler.patch
diff --git a/omni/accelerators/cache/pd.py b/omni/accelerators/cache/pd.py
index 4db0382151..159b7bba94 100644
--- a/omni/accelerators/cache/pd.py
+++ b/omni/accelerators/cache/pd.py
@@ -12,20 +12,31 @@ from . import kv_cache_interface as itfc
 class OmniBiGroupDataDistManager(LLMDataDistManager):
     def __init__(self, vllm_config):
         super().__init__(vllm_config)
-        self.registerd_kv_caches: list[list[Cache]] = [[], []]
+        self.registered_kv_caches: list[list[Cache]] = [[], []]
 
     @override
     def register_memory(self, kv_caches: dict[str, torch.Tensor]):
-        if any(len(group_cache) > 0 for group_cache in self.registerd_kv_caches):
-            raise ValueError("Attr `registerd_kv_caches` must be empty before register kv_caches.")
+        if any(len(group_cache) > 0 for group_cache in self.registered_kv_caches):
+            raise ValueError("Attr `registered_kv_caches` must be empty before register kv_caches.")
 
         # NOTE: flatten_kv_caches is a nested list like [[k1,k2,...,kL], [v1,v2,...,vL]]
        # if KV is just one tensor, then it's [[kv1,kv2,...,kvL]]
         flatten_kv_caches: list[list[torch.Tensor]] = unzip_kv_cache_dict(kv_caches)
         num_layers = len(flatten_kv_caches[0])
+        PATTERN = itfc.PATTERN
         # partition layer indices into full and omni
-        full_layer_idx = [i for i in range(num_layers) if itfc.PATTERN[i] == 0]
-        omni_layer_idx = [i for i in range(num_layers) if itfc.PATTERN[i] == 1]
+        if self.data_dist_config.is_prefill:
+            # 1. get the pp rank of the current worker
+            # 2. get the layer start and end of the corresponding pp stage
+            # 3. slice itfc.
PATTERN + pp_rank = self.rank // self.prefill_tp_dp_size + prefill_pp_partitions = self.data_dist_config.kv_producer_pp_partitions + pp_start_layer_idx = sum(prefill_pp_partitions[:pp_rank]) + pp_end_layer_idx = pp_start_layer_idx + prefill_pp_partitions[pp_rank] + PATTERN = itfc.PATTERN[pp_start_layer_idx : pp_end_layer_idx] + + full_layer_idx = [i for i in range(num_layers) if PATTERN[i] == 0] + omni_layer_idx = [i for i in range(num_layers) if PATTERN[i] == 1] layer_idx = [full_layer_idx, omni_layer_idx] # check validity @@ -48,6 +59,13 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): logger.warning("Trying to register grouped KV caches for OMNI attention, with " f"{len(full_layer_idx)} full attn layers and {len(omni_layer_idx)} omni attn layers.") + if self.data_dist_config.is_prefill: + self._register_caches_prefill(flatten_kv_caches, layer_idx) + else: + self._register_caches_decode(flatten_kv_caches, layer_idx) + logger.error(f" ***** registered_kv_caches num:{sum([len(group_kv_caches) for group_kv_caches in self.registered_kv_caches])}") + + def _register_caches_prefill(self, flatten_kv_caches, layer_idx): # model_id related N = len(flatten_kv_caches) used_ids = set() @@ -55,27 +73,49 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): for model_id, sub_kv_caches in enumerate(flatten_kv_caches): # sub_kv_caches is a list of Tensors, whose length is number of layers - for flag in range(len(self.registerd_kv_caches)): + for flag in range(len(self.registered_kv_caches)): group_kv_caches = [sub_kv_caches[j] for j in layer_idx[flag]] cache_desc = CacheDesc(num_tensors=len(group_kv_caches), shape=tuple(group_kv_caches[0].shape), data_type=TORCH_DTYPE_TO_NPU_DTYPE[group_kv_caches[0].dtype]) cache_addrs = [int(item.data_ptr()) for item in group_kv_caches] - if self.data_dist_config.is_prefill: - # NOTE: when assigning model_id to cache_key, we consider KV group information - # e.g., if registered_kv_caches = [[K_full, V_full], [K_omni, V_omni]] - # then model_ids should be [[0, 1], [2, 3]] - cur_id = flag * N + model_id - if cur_id in used_ids: - raise RuntimeError(f"Error! ID already used. {N=}, {model_id=}, {used_ids=}, {cur_id=}.") - used_ids.add(cur_id) - cache_key = BlocksCacheKey(self.data_dist_engine.cluster_id, model_id=cur_id) - else: - cache_key = None + # NOTE: when assigning model_id to cache_key, we consider KV group information + # e.g., if registered_kv_caches = [[K_full, V_full], [K_omni, V_omni]] + # then model_ids should be [[0, 1], [2, 3]] + cur_id = flag * N + model_id + if cur_id in used_ids: + raise RuntimeError(f"Error! ID already used. 
{N=}, {model_id=}, {used_ids=}, {cur_id=}.") + used_ids.add(cur_id) + cache_key = BlocksCacheKey(self.data_dist_engine.cluster_id, model_id=cur_id) cache = self.data_dist_engine.cache_manager.register_blocks_cache(cache_desc, cache_addrs, cache_key) - self.registerd_kv_caches[flag].append(cache) - logger.error(f" ***** registerd_kv_caches num:{sum([len(group_kv_caches) for group_kv_caches in self.registerd_kv_caches])}") + self.registered_kv_caches[flag].append(cache) + + def _register_caches_decode(self, flatten_kv_caches, layer_idx): + prefill_pp_partitions = self.data_dist_config.kv_producer_pp_partitions + for flag in range(len(self.registered_kv_caches)): + cnt_layer_num = 0 + layer_idx_start = 0 + for cur_pp_stage_layer_num in prefill_pp_partitions: + cur_pp_stage_kv_caches = [] + layer_idx_end = layer_idx_start + while layer_idx_end < len(layer_idx[flag]) and cnt_layer_num <= layer_idx[flag][layer_idx_end] < cnt_layer_num + cur_pp_stage_layer_num: + layer_idx_end += 1 + flag_stage_layer_idx = layer_idx[flag][layer_idx_start : layer_idx_end] + layer_idx_start = layer_idx_end + for sub_kv_caches in flatten_kv_caches: + # sub_kv_caches is a list of Tensors, whose length is number of layers + # flag_stage_layer_idx = layer_idx[flag][cnt_layer_num : cnt_layer_num + cur_pp_stage_layer_num] + group_kv_caches = [sub_kv_caches[j] for j in flag_stage_layer_idx] + # group_kv_caches = [sub_kv_caches[j] for j in layer_idx[flag] if cnt_layer_num <= j < cnt_layer_num + cur_pp_stage_layer_num] + cache_desc = CacheDesc(num_tensors=len(group_kv_caches), shape=tuple(group_kv_caches[0].shape), + data_type=TORCH_DTYPE_TO_NPU_DTYPE[group_kv_caches[0].dtype]) + cache_addrs = [int(item.data_ptr()) for item in group_kv_caches] + + cache = self.data_dist_engine.cache_manager.register_blocks_cache(cache_desc, cache_addrs, None) + cur_pp_stage_kv_caches.append(cache) + self.registered_kv_caches[flag].append(cur_pp_stage_kv_caches) + cnt_layer_num += cur_pp_stage_layer_num @override def pull_kv(self, src_blocks: list[int], tgt_blocks: list[list[int]], prompt_cluster_id: int): @@ -90,34 +130,34 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): torch.npu.set_device(f"npu:{self.local_rank}") sink, recent = itfc.SINK, itfc.RECENT omni_max_blocks = sink + recent - N = len(self.registerd_kv_caches[0]) - used_ids = set() + N = len(self.registered_kv_caches[0]) + # used_ids = set() - for flag in range(len(self.registerd_kv_caches)): + for flag in range(len(self.registered_kv_caches)): group_src_blocks: list[int] = src_blocks[flag] group_tgt_blocks: list[int] = tgt_blocks[flag] - for model_id, kv_cache in enumerate(self.registerd_kv_caches[flag]): - cur_id = flag * N + model_id - if cur_id in used_ids: - raise RuntimeError(f"Error! ID already pulled. 
{N=}, {model_id=}, {used_ids=}, {cur_id=}.") - used_ids.add(cur_id) + for pp_stage_ind, cur_pp_stage_kv_caches in enumerate(self.registered_kv_caches[flag]): + for model_id, kv_cache in enumerate(cur_pp_stage_kv_caches): + cur_id = flag * N + model_id + cluster_id_pp_offset = pp_stage_ind * self.prefill_tp_dp_size - prompt_cache_key = BlocksCacheKey( - prompt_cluster_id=prompt_cluster_id, model_id=cur_id) - if flag == 0: - self._pull_blocks(prompt_cache_key, kv_cache, - group_src_blocks, group_tgt_blocks) - else: - if len(group_tgt_blocks) == 0: - continue - tmp_src, tmp_tgt = group_src_blocks, group_tgt_blocks - if len(group_src_blocks) < omni_max_blocks: - tmp_tgt = group_tgt_blocks[:len(group_src_blocks)] - elif len(group_src_blocks) > omni_max_blocks: - tmp_src = group_src_blocks[:sink] + group_src_blocks[-recent:] - if len(tmp_src) != len(tmp_tgt): - raise RuntimeError("src and tgt cannot match for omni kv caches. " - f"{src_blocks=}, {tgt_blocks=}, " - f"{len(tmp_src)=}, {len(tmp_tgt)=}.") - self._pull_blocks(prompt_cache_key, kv_cache, - tmp_src, tmp_tgt) + prompt_cache_key = BlocksCacheKey( + prompt_cluster_id=prompt_cluster_id + cluster_id_pp_offset, model_id=cur_id) + if flag == 0: + self._pull_blocks(prompt_cache_key, kv_cache, + group_src_blocks, group_tgt_blocks) + else: + if len(group_tgt_blocks) == 0: + continue + tmp_src, tmp_tgt = group_src_blocks, group_tgt_blocks + if len(group_src_blocks) < omni_max_blocks: + tmp_tgt = group_tgt_blocks[:len(group_src_blocks)] + elif len(group_src_blocks) > omni_max_blocks: + tmp_src = group_src_blocks[:sink] + group_src_blocks[-recent:] + if len(tmp_src) != len(tmp_tgt): + raise RuntimeError("src and tgt cannot match for omni kv caches. " + f"{src_blocks=}, {tgt_blocks=}, " + f"{tmp_src=}, {tmp_tgt=}, " + f"{len(tmp_src)=}, {len(tmp_tgt)=}.") + self._pull_blocks(prompt_cache_key, kv_cache, + tmp_src, tmp_tgt) diff --git a/omni/accelerators/pd/llmdatadist_manager.py b/omni/accelerators/pd/llmdatadist_manager.py index e574a0812f..76f182fe1e 100644 --- a/omni/accelerators/pd/llmdatadist_manager.py +++ b/omni/accelerators/pd/llmdatadist_manager.py @@ -157,6 +157,12 @@ class LLMDataDistManager: self.registered_link_infos = {} + if self.data_dist_config.is_prefill: + prefill_server_groups = [self.data_dist_config.local_group] + else: + prefill_server_groups = self.data_dist_config.global_rank_table.prefill_group + self.prefill_tp_dp_size = len(prefill_server_groups[0].device_list) // self.data_dist_config.kv_producer_pp_size + def get_real_remote_cluster_ids(self, meta: "ReqMeta"): remote_cluster_ids = self.registered_link_infos.get( (meta.remote_cluster_id, meta.remote_dp_rank, self.rank), None) @@ -283,9 +289,9 @@ class LLMDataDistManager: prefill_tp_dp_size = len(prefill_server_groups[0].device_list) // self.data_dist_config.kv_producer_pp_size for pp_stage_ind, cur_pp_stage_kv_caches in enumerate(self.registered_kv_caches): for model_id, kv_cache in enumerate(cur_pp_stage_kv_caches): - cluster_id_pp_offset = pp_stage_ind * prefill_tp_dp_size + cluster_id_pp_offset = pp_stage_ind * self.prefill_tp_dp_size prompt_cache_key = BlocksCacheKey( - prompt_cluster_id=prompt_cluster_id +cluster_id_pp_offset, model_id=model_id + prompt_cluster_id=prompt_cluster_id + cluster_id_pp_offset, model_id=model_id ) self._pull_blocks(prompt_cache_key, kv_cache, src_blocks, tgt_blocks) diff --git a/omni/adaptors/vllm/patches/pp_scheduler.patch b/omni/adaptors/vllm/patches/pp_scheduler.patch new file mode 100644 index 0000000000..d2ddd204ab --- /dev/null +++ 
b/omni/adaptors/vllm/patches/pp_scheduler.patch @@ -0,0 +1,67 @@ +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 1f01c1e0a..a58ab11c5 100755 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -91,6 +91,7 @@ class Scheduler(SchedulerInterface): + + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs ++ self.max_num_scheduled_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len +@@ -264,6 +265,8 @@ class Scheduler(SchedulerInterface): + blocks_to_swap_out: list[list[tuple[int, int]]] = [[] for _ in range(num_groups)] + blocks_to_swap_in: list[list[tuple[int, int]]] = [[] for _ in range(num_groups)] + ++ num_scheduled_reqs = 0 ++ + # For logging. + scheduled_timestamp = time.monotonic() + +@@ -374,6 +377,7 @@ class Scheduler(SchedulerInterface): + + # Schedule the request. + scheduled_running_reqs.append(request) ++ num_scheduled_reqs += 1 + if request.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. +@@ -434,9 +438,10 @@ class Scheduler(SchedulerInterface): + skipped_waiting_requests = create_request_queue(self.policy) + + # Next, schedule the WAITING requests. ++ pp_size = self.vllm_config.parallel_config.pipeline_parallel_size + if not preempted_reqs: + while self.waiting and token_budget > 0: +- if len(self.running) == self.max_num_running_reqs: ++ if (pp_size == 1 and len(self.running) == self.max_num_running_reqs) or num_scheduled_reqs == self.max_num_scheduled_reqs: + break + + request = self.waiting.peek_request() +@@ -596,6 +601,7 @@ class Scheduler(SchedulerInterface): + else: + raise RuntimeError( + f"Invalid request status: {request.status}") ++ num_scheduled_reqs += 1 + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) +@@ -625,12 +631,15 @@ class Scheduler(SchedulerInterface): + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 +- assert len(self.running) <= self.max_num_running_reqs ++ if pp_size == 1: ++ assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + +- len(scheduled_running_reqs) <= len(self.running)) ++ len(scheduled_running_reqs) == num_scheduled_reqs) ++ assert num_scheduled_reqs <= len(self.running) ++ assert num_scheduled_reqs <= self.max_num_scheduled_reqs + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. 
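
The pp_scheduler.patch above changes the scheduler's admission rule rather than its data path: with pipeline_parallel_size > 1, up to pp_size micro-batches are in flight at once, so len(self.running) may legitimately exceed max_num_seqs; the patch therefore stops capping the RUNNING queue in that case and instead caps the number of requests scheduled in one step through the new num_scheduled_reqs counter and max_num_scheduled_reqs limit. The standalone Python sketch below restates that rule; it is illustrative only (SchedulerCaps and can_schedule_waiting_request are made-up names, not part of vLLM's Scheduler).

from dataclasses import dataclass

@dataclass
class SchedulerCaps:
    max_num_running_reqs: int     # scheduler_config.max_num_seqs
    max_num_scheduled_reqs: int   # per-step cap introduced by the patch
    pipeline_parallel_size: int

def can_schedule_waiting_request(caps: SchedulerCaps,
                                 num_running: int,
                                 num_scheduled_this_step: int) -> bool:
    # pp_size == 1 keeps the original behaviour: the RUNNING queue is capped.
    if caps.pipeline_parallel_size == 1 and num_running >= caps.max_num_running_reqs:
        return False
    # In all cases, at most max_num_scheduled_reqs requests are scheduled per step.
    if num_scheduled_this_step >= caps.max_num_scheduled_reqs:
        return False
    return True

# With pp_size=2 and max_num_seqs=4, five running requests are allowed,
# but no more than four new schedules happen in a single step.
caps = SchedulerCaps(max_num_running_reqs=4, max_num_scheduled_reqs=4,
                     pipeline_parallel_size=2)
assert can_schedule_waiting_request(caps, num_running=5, num_scheduled_this_step=3)
assert not can_schedule_waiting_request(caps, num_running=5, num_scheduled_this_step=4)
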
diff --git a/omni/layers/attention/deepseek_mla.py b/omni/layers/attention/deepseek_mla.py index 2ab1d27706..4361341caf 100644 --- a/omni/layers/attention/deepseek_mla.py +++ b/omni/layers/attention/deepseek_mla.py @@ -1412,6 +1412,7 @@ class DeepseekMLA(nn.Module): attn_metadata.slot_mapping, kv_cache[1], kv_cache[0], + c_kv_scale=self.kv_scale_reci_tile, epsilon=self.kv_a_layernorm.variance_epsilon, cache_mode="PA_NZ", is_output_kv=True) @@ -1470,7 +1471,7 @@ class DeepseekMLA(nn.Module): ): prefill_q = q[computed_tokens:computed_tokens + actual_seq_qlen[-1]] if prefill_metadata.kv_index_list and kv_cache is not None and isinstance(kv_cache, Tuple) and \ - kv_cache[0].numel() > 0: + kv_cache[0].numel() > 0 and not self.fa_quant: block_num, block_size, head_size, _ = kv_cache[0].shape kv_cache_a = (kv_cache[0] diff --git a/omni/layers/moe/deepseek_moe.py b/omni/layers/moe/deepseek_moe.py index cfb40203a2..a6b2042fd3 100644 --- a/omni/layers/moe/deepseek_moe.py +++ b/omni/layers/moe/deepseek_moe.py @@ -325,7 +325,7 @@ class DeepseekMoE(nn.Module): from omni.accelerators.placement.omni_placement.omni_planner import OmniPlanner self.planner = OmniPlanner(device="npu", rank=get_world_group().rank_in_group, - world_size=get_world_group().world_size, + world_size=get_ep_group().world_size, num_experts=self.n_routed_experts, num_redundancy_shared_expert_rank=self.redundancy_shared_expert_num) self.moe_layer_idx = OmniPlanner.get_deepseek_v3_moe_layer_idx(moe_prefix, first_k_dense_replace=self.first_k_dense_replace) diff --git a/omni/models/deepseek/deepseek_v3_a2.py b/omni/models/deepseek/deepseek_v3_a2.py index 4b8a26e39e..2977595fdf 100644 --- a/omni/models/deepseek/deepseek_v3_a2.py +++ b/omni/models/deepseek/deepseek_v3_a2.py @@ -333,10 +333,8 @@ class DeepseekDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - # hidden_states, residual = self.input_layernorm( - # hidden_states, residual, quant_symbol=True) hidden_states, residual = self.input_layernorm( - hidden_states, residual, quant_symbol=True) + hidden_states, residual, quant_symbol=(not model_extra_config.operator_opt_config.use_mlaprolog and not model_extra_config.operator_opt_config.enable_dsa)) hidden_states = self.self_attn( positions=positions, diff --git a/omni/models/pangu/pangu_ultra_moe_a2.py b/omni/models/pangu/pangu_ultra_moe_a2.py index a5644444c5..a7926939ce 100644 --- a/omni/models/pangu/pangu_ultra_moe_a2.py +++ b/omni/models/pangu/pangu_ultra_moe_a2.py @@ -332,7 +332,7 @@ class PanguUltraMoEDecoderLayer(nn.Module): # Adapt: adapt for w8a8 dynamic, do quant # Combines residual add and rmsnorm hidden_states, residual = self.input_layernorm( - hidden_states, residual, quant_symbol=True) + hidden_states, residual, quant_symbol=(not model_extra_config.operator_opt_config.use_mlaprolog and not model_extra_config.operator_opt_config.enable_dsa)) hidden_states = self.self_attn( positions=positions, -- Gitee From 140481b75c1e499b636eb5018857f0c4a3f38b52 Mon Sep 17 00:00:00 2001 From: lizan <13590736+lizan9848@user.noreply.gitee.com> Date: Fri, 28 Nov 2025 01:30:05 +0000 Subject: [PATCH 2/2] update omni/accelerators/pd/llmdatadist_manager.py. 
Signed-off-by: lizan <13590736+lizan9848@user.noreply.gitee.com> --- omni/accelerators/pd/llmdatadist_manager.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/omni/accelerators/pd/llmdatadist_manager.py b/omni/accelerators/pd/llmdatadist_manager.py index 76f182fe1e..35bca396ac 100644 --- a/omni/accelerators/pd/llmdatadist_manager.py +++ b/omni/accelerators/pd/llmdatadist_manager.py @@ -282,11 +282,6 @@ class LLMDataDistManager: # The preliminary reason is that the context is lost when multiple coroutines pull kv. torch.npu.set_device(f"npu:{self.local_rank}") if self.data_dist_config.kv_producer_pp_size > 1: - if self.data_dist_config.is_prefill: - prefill_server_groups = [self.data_dist_config.local_group] - else: - prefill_server_groups = self.data_dist_config.global_rank_table.prefill_group - prefill_tp_dp_size = len(prefill_server_groups[0].device_list) // self.data_dist_config.kv_producer_pp_size for pp_stage_ind, cur_pp_stage_kv_caches in enumerate(self.registered_kv_caches): for model_id, kv_cache in enumerate(cur_pp_stage_kv_caches): cluster_id_pp_offset = pp_stage_ind * self.prefill_tp_dp_size -- Gitee
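
Taken together, the two patches assume the prefill (KV-producer) side may run with pipeline parallelism: each prefill worker registers only the slice of itfc.PATTERN owned by its PP stage, and a decode worker reaches the right producer by offsetting the prompt cluster id with pp_stage_ind * prefill_tp_dp_size, where prefill_tp_dp_size = number of prefill devices // kv_producer_pp_size. A minimal Python sketch of these two mappings follows; the helper names and the example numbers (16 prefill devices split into 2 PP stages, an 8-layer pattern) are assumptions for illustration, not values taken from the code.

def stage_layer_slice(pattern, pp_partitions, rank, prefill_tp_dp_size):
    # Mirrors register_memory() in omni/accelerators/cache/pd.py: a prefill
    # worker only sees the layers of its own pipeline stage.
    pp_rank = rank // prefill_tp_dp_size
    start = sum(pp_partitions[:pp_rank])
    end = start + pp_partitions[pp_rank]
    return pattern[start:end]

def prompt_cluster_for_stage(prompt_cluster_id, pp_stage_ind,
                             num_prefill_devices, kv_producer_pp_size):
    # Mirrors the cluster_id_pp_offset computation in pull_kv(): each prefill
    # PP stage is reached through a cluster id shifted by prefill_tp_dp_size.
    prefill_tp_dp_size = num_prefill_devices // kv_producer_pp_size
    return prompt_cluster_id + pp_stage_ind * prefill_tp_dp_size

pattern = [0, 0, 1, 1, 0, 1, 0, 0]   # 0 = full attention layer, 1 = omni attention layer
pp_partitions = [4, 4]               # layers per prefill PP stage
print(stage_layer_slice(pattern, pp_partitions, rank=9, prefill_tp_dp_size=8))
# -> [0, 1, 0, 0]  (rank 9 belongs to PP stage 1, which owns layers 4..7)
print(prompt_cluster_for_stage(prompt_cluster_id=3, pp_stage_ind=1,
                               num_prefill_devices=16, kv_producer_pp_size=2))
# -> 11  (stage 1 of the producer behind cluster 3)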