From 2fbaa09eb76f93bfb909074492a7c220d8ac3886 Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Mon, 24 Nov 2025 14:05:51 +0800 Subject: [PATCH 01/15] support new datadist for dynamic build new links / rebuild broken links, missed some detailed implementations --- omni/accelerators/cache/pd.py | 54 +- omni/accelerators/pd/__init__.py | 3 +- .../pd/llmdatadist_connector_v2.py | 967 ++++++++++++++++++ .../accelerators/pd/llmdatadist_manager_v1.py | 468 +++++++++ 4 files changed, 1478 insertions(+), 14 deletions(-) create mode 100644 omni/accelerators/pd/llmdatadist_connector_v2.py create mode 100644 omni/accelerators/pd/llmdatadist_manager_v1.py diff --git a/omni/accelerators/cache/pd.py b/omni/accelerators/cache/pd.py index 4db038215..ebbcdc612 100644 --- a/omni/accelerators/cache/pd.py +++ b/omni/accelerators/cache/pd.py @@ -1,17 +1,30 @@ from typing_extensions import override import torch from llm_datadist.v2.llm_types import Cache, CacheDesc, BlocksCacheKey -from omni.accelerators.pd.llmdatadist_manager import ( - LLMDataDistManager, - TORCH_DTYPE_TO_NPU_DTYPE, - unzip_kv_cache_dict, - logger, -) +import os +USE_NEW_DATADIST = os.getenv("ENABLE_DYNAMIC_DATADIST", "0") == "1" +if USE_NEW_DATADIST: + from omni.accelerators.pd.llmdatadist_manager_v1 import ( + LLMDataDistManager, + TORCH_DTYPE_TO_NPU_DTYPE, + unzip_kv_cache_dict, + logger, + ) +else: + from omni.accelerators.pd.llmdatadist_manager import ( + LLMDataDistManager, + TORCH_DTYPE_TO_NPU_DTYPE, + unzip_kv_cache_dict, + logger, + ) from . import kv_cache_interface as itfc class OmniBiGroupDataDistManager(LLMDataDistManager): - def __init__(self, vllm_config): - super().__init__(vllm_config) + def __init__(self, vllm_config, local_host_ip=0, host_port=0): + if USE_NEW_DATADIST: + super().__init__(vllm_config, local_host_ip, host_port) + else: + super().__init__(vllm_config) self.registerd_kv_caches: list[list[Cache]] = [[], []] @override @@ -78,7 +91,7 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): logger.error(f" ***** registerd_kv_caches num:{sum([len(group_kv_caches) for group_kv_caches in self.registerd_kv_caches])}") @override - def pull_kv(self, src_blocks: list[int], tgt_blocks: list[list[int]], prompt_cluster_id: int): + def pull_kv(self, src_blocks: list[int], tgt_blocks: list[list[int]], prompt_cluster_id: int, prefill_dp_rank: int=0): """Pull KV Caches for both full and omni attention layers. The input `tgt_blocks` is a list of lists of ints like [[blk1,...,blk100], [blk1,blk2,blk3]], where the first sublist is the block table for full attention layers while the second is @@ -105,8 +118,12 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): prompt_cache_key = BlocksCacheKey( prompt_cluster_id=prompt_cluster_id, model_id=cur_id) if flag == 0: - self._pull_blocks(prompt_cache_key, kv_cache, - group_src_blocks, group_tgt_blocks) + if USE_NEW_DATADIST: + ret = self._pull_blocks(prompt_cache_key, kv_cache, + group_src_blocks, group_tgt_blocks) + else: + self._pull_blocks(prompt_cache_key, kv_cache, + group_src_blocks, group_tgt_blocks) else: if len(group_tgt_blocks) == 0: continue @@ -119,5 +136,16 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): raise RuntimeError("src and tgt cannot match for omni kv caches. 
" f"{src_blocks=}, {tgt_blocks=}, " f"{len(tmp_src)=}, {len(tmp_tgt)=}.") - self._pull_blocks(prompt_cache_key, kv_cache, - tmp_src, tmp_tgt) + if USE_NEW_DATADIST: + ret = self._pull_blocks(prompt_cache_key, kv_cache, + tmp_src, tmp_tgt) + else: + self._pull_blocks(prompt_cache_key, kv_cache, + tmp_src, tmp_tgt) + if USE_NEW_DATADIST: + if not ret: + self._refresh_link(prompt_cluster_id, prefill_dp_rank, self.rank) + ret_updated = self._pull_blocks(prompt_cache_key, kv_cache, + tmp_src, tmp_tgt) + if not ret_updated: + raise RuntimeError(f"Failed to pull kv even if rebuild the kv link!") diff --git a/omni/accelerators/pd/__init__.py b/omni/accelerators/pd/__init__.py index 755f74583..1eba5214b 100644 --- a/omni/accelerators/pd/__init__.py +++ b/omni/accelerators/pd/__init__.py @@ -11,7 +11,8 @@ def register(): KVConnectorFactory.register_connector( "AscendHcclConnectorV1", "omni.accelerators.pd.omni_cache_connector_v1" if os.getenv("ENABLE_OMNI_CACHE", "0") == "1" - else "omni.accelerators.pd.llmdatadist_connector_v1", + elif "omni.accelerators.pd.llmdatadist_connector_v2" if os.getenv("ENABLE_DYNAMIC_LLMDATADIST", "0") == "1" + else "omni.accelerators.pd.llmdatadist_connector", "LLMDataDistConnector" ) diff --git a/omni/accelerators/pd/llmdatadist_connector_v2.py b/omni/accelerators/pd/llmdatadist_connector_v2.py new file mode 100644 index 000000000..633d93837 --- /dev/null +++ b/omni/accelerators/pd/llmdatadist_connector_v2.py @@ -0,0 +1,967 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + +import json +from collections.abc import Iterator +import math +import threading +from typing import TYPE_CHECKING, Any, Optional, Union, Mapping +import zmq +import os +import pickle +import time +import socket + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.logger import logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput + +from omni.accelerators.pd.utils import get_config_from_dict_or_env + +if TYPE_CHECKING: + from vllm.config import VllmConfig, KVTransferConfig + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request +from vllm.v1.request import Request +from vllm.utils import round_down +from dataclasses import dataclass +from collections import defaultdict +import torch +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group, get_pp_group) + +from vllm.utils import get_open_port +from vllm.v1.request import RequestStatus +import queue +from concurrent.futures import ThreadPoolExecutor +from omni.tools.profiler.trace.tracing import get_local_ip + +GET_META_MSG = b"get_meta_msg" + +thread_dump_path = os.environ.get("VLLM_THREAD_DUMP_PATH", "/tmp/vllm_thread_info") +BLOCK_RELEASE_DELAY = int(os.environ.get("BLOCK_RELEASE_DELAY", 600)) # seconds, use to free blocks when the request is finished for a long time +LLMDATADIST_BASE_PORT = int(os.environ.get("VLLM_LLMDATADIST_BASE_PORT", 15567)) + +from omni.accelerators.pd.llmdatadist_manager_v1 import LLMDataDistManager, LLMDataDistConfig +from omni.tools.profiler.apply_profiler_patches import patch_request 
+patch_request() + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_cluster_id: str + spec_token_ids: Optional[list[int]] + remote_dp_rank: Optional[int] + remote_request_id: Optional[str] + trace_headers: Optional[Mapping[str, str]] = None + +@dataclass +class ReqMetaPrefill: + finish_time: float + +class DatadistConnectorMetadata(KVConnectorMetadata): + """Metadata for datadist connector.""" + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + trace_headers: Optional[Mapping[str, str]] = None, + ): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_host=kv_transfer_params["remote_host_ip"], + remote_cluster_id=kv_transfer_params["remote_cluster_id"], + spec_token_ids=kv_transfer_params["spec_token_ids"], + remote_dp_rank=kv_transfer_params.get("remote_dp_rank", 0), + remote_request_id=kv_transfer_params.get("remote_request_id", None), + trace_headers=trace_headers or {}, + ) + +class DatadistConnectorMetadataPrefill(KVConnectorMetadata): + """Metadata for datadist connector.""" + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req( + self, + request_id: str, + finish_time: float, + ): + self.requests[request_id] = ReqMeta( + finish_time=finish_time + ) + + +class LLMDataDistConnector(KVConnectorBase_V1): + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + if vllm_config.kv_transfer_config is None: + raise RuntimeError("vllm_config.kv_transfer_config cannot be None") + + if vllm_config.model_config.is_deepseek_mla: + vllm_config.kv_transfer_config.kv_parallel_size = 1 + logger.info("Set kv_parallel_size to 1 when use deepseek mla model.") + + local_host_ip = get_local_ip() + local_host_port = LLMDATADIST_BASE_PORT + self.datadist_config = LLMDataDistConfig(vllm_config, local_host_ip, local_host_port, ignore_load_rank=True) + self.host_cluster_id = self.datadist_config.host_cluster_id + self.host_ip = local_host_ip + # Introduce the environment variable VLLM_LLMDATADIST_ZMQ_PORT to resolve ZMQ connection conflicts during + # multi-P deployments on the same machine. + # This variable should not be set separately unless specifically required for this scenario. 
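# The resolved port is additionally offset by the DP rank below. host_port itself is the
# ZMQ PULL socket bound by the prefill worker to receive pulled-KV notifications from the
# decode side, while host_port + 1000 and host_port + 2000 carry trace headers (when
# profiling is enabled) on the prefill and decode side respectively.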
+ self.host_port = get_config_from_dict_or_env(vllm_config.kv_transfer_config, "kv_port", + "VLLM_LLMDATADIST_ZMQ_PORT", "5568", int) + dp_rank = vllm_config.parallel_config.data_parallel_rank + self.host_port += dp_rank + self.is_prefill = vllm_config.kv_transfer_config.kv_role == "kv_producer" + + if role == KVConnectorRole.SCHEDULER: + if self.is_prefill: + self.connector_scheduler = PrefillConnectorScheduler(vllm_config, self.host_cluster_id, self.host_ip, str(self.host_port), str(self.host_port + 1000)) + else: + self.connector_scheduler = DecodeConnectorScheduler(vllm_config, str(self.host_port + 2000)) + self.connector_worker = None + elif role == KVConnectorRole.WORKER: + if self.is_prefill: + self.connector_worker = PrefillConnectorWorker(vllm_config, str(self.host_ip), str(self.host_port), str(self.host_port + 1000)) + else: + self.connector_worker = DecodeConnectorWorker(vllm_config, str(self.host_ip), self.host_cluster_id, str(self.host_port + 2000)) + self.connector_scheduler = None + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + if self.connector_scheduler is None: + raise RuntimeError("self.connector_scheduler cannot be None") + return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + if self.connector_scheduler is None: + raise RuntimeError("self.connector_scheduler cannot be None") + return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + if self.connector_scheduler is None: + raise RuntimeError("self.connector_scheduler cannot be None") + return self.connector_scheduler.build_connector_metadata(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + spec_token_ids: Optional[list[int]] = [] + ) -> tuple[bool, Optional[dict[str, Any]]]: + if self.connector_scheduler is None: + raise RuntimeError("self.connector_scheduler cannot be None") + return self.connector_scheduler.request_finished(request, block_ids, spec_token_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + if self.connector_worker is None: + raise RuntimeError("self.connector_worker cannot be None") + return self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + if self.connector_worker is None: + raise RuntimeError("self.connector_worker cannot be None") + return self.connector_worker.get_finished(self._connector_metadata) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + if self.connector_worker is None: + raise RuntimeError("self.connector_worker cannot be None") + if not isinstance(self._connector_metadata, Union[DatadistConnectorMetadata, DatadistConnectorMetadataPrefill]): + raise RuntimeError("self._connector_metadata must be an instance of DatadistConnectorMetadata or DatadistConnectorMetadataPrefill") + 
self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """Connector does not do layerwise saving.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """Connector does not save explicitly.""" + pass + + def wait_for_save(self): + """Connector does not save explicitly.""" + pass + + def pop_trace_headers(self, req_id: str) -> dict: + if self.connector_scheduler is None: + raise RuntimeError("self.connector_scheduler cannot be None") + return self.connector_scheduler.pop_trace_headers(req_id) + +class PrefillConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config, host_cluster_id: str, host_ip: str, host_port: str, trace_p_port: str): + self.vllm_config = vllm_config + self.host_cluster_id = host_cluster_id + self.host_ip = host_ip + self.host_port = host_port + self.trace_p_port=trace_p_port + logger.info("Initializing LLMDataDist Scheduler %s %s %s", host_cluster_id, host_ip, host_port) + # initialize the dict to save requests finish time + self.requests_finish_time = dict() + # req_id -> headers + self.sending_trace_headers: dict[str, dict] = {} + + self._transfer_lock = threading.Lock() + self.ctx = zmq.Context() + self.pull_socket = self.ctx.socket(zmq.PULL) + self.pull_socket.bind(f"tcp://{host_ip}:{trace_p_port}") + if os.getenv("PROFILING_NAMELIST", None): + self._listener_thread = threading.Thread(target=self._listen_worker_headers, daemon=True) + self._listener_thread.start() + + def _listen_worker_headers(self): + while True: + try: + msg_str = self.pull_socket.recv_string() + msg_list = json.loads(msg_str) + for msg in msg_list: + req_id = msg['remote_request_id'] or msg['request_id'] + headers = msg.get('trace_headers', {}) + with self._transfer_lock: + self.sending_trace_headers[req_id] = headers + except Exception as e: + logger.error(f"Failed to receive worker header (P): {e}") + time.sleep(1) + + def pop_trace_headers(self, req_id: str) -> dict: + with self._transfer_lock: + trace_headers = self.sending_trace_headers.pop(req_id, {}) + if trace_headers: + return trace_headers + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + pass + + def build_connector_metadata( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + metadata = DatadistConnectorMetadataPrefill() + # add requests finish time to metadata, to pass to worker connector + metadata.requests = {req_id: ReqMetaPrefill(finish_time=finish_time) + for req_id, finish_time in self.requests_finish_time.items()} + self.requests_finish_time.clear() + return metadata + + def request_finished( + self, + request: "Request", + block_ids: list[int], + spec_token_ids: Optional[list[int]] = [] + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. 
+ """ + if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: + return False, None + + delay_free_blocks = len(block_ids) > 0 + # record the finish time of the request + if delay_free_blocks: + self.requests_finish_time[request.request_id] = time.monotonic() + + return delay_free_blocks, dict( + remote_block_ids=block_ids, + remote_cluster_id=self.host_cluster_id, + remote_host_ip=f"tcp://{self.host_ip}:{self.host_port}", + spec_token_ids=spec_token_ids, + remote_dp_rank=self.vllm_config.parallel_config.data_parallel_rank, + remote_request_id=request.request_id + ) + + +class PrefillConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: "VllmConfig", host_ip: str, host_port: str, trace_p_port: str): + # Metadata. + self.host_ip = host_ip + self.host_port = host_port + self.trace_p_port = trace_p_port + self.rank = get_tensor_model_parallel_rank() + if self.rank == 0 and get_pp_group().is_last_rank: + self.ctx = zmq.Context() + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"tcp://{self.host_ip}:{self.host_port}") + logger.info(f"ConnectWorker bind tcp://{self.host_ip}:{self.host_port}") + self._transfer_lock = threading.Lock() + self.receive_req_list = [] + thread_name = "prefill_connector_get_pulled_kv_req_list" + self.thread = threading.Thread(target=self.get_pulled_kv_req_list, daemon=True, name=thread_name) + self.thread.start() + dump_thread_to_file(self.thread, thread_name, thread_dump_path) + + # check whether omni attention is enabled + manager_cls = LLMDataDistManager + if vllm_config.additional_config and "enable_omni_attn" in vllm_config.additional_config: + # do import only when necessary + from omni.accelerators.cache import OmniBiGroupDataDistManager, check_omni_attn_cmd_arg + use_omni_attn_mgr = check_omni_attn_cmd_arg(vllm_config.additional_config) + if use_omni_attn_mgr: + manager_cls = OmniBiGroupDataDistManager + logger.warning(f"PrefillingConnector is using Omni datadist manager for KV transfer.") + local_host_port = LLMDATADIST_BASE_PORT + self.datadist_manager = manager_cls(vllm_config, self.host_ip, local_host_port) + + # initialize the dict to save requests finish time + self.requests_finish_time = dict() + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + self.datadist_manager.register_memory(kv_caches) + + def start_load_kv(self, metadata: DatadistConnectorMetadataPrefill): + pass + + def get_finished(self, metadata: DatadistConnectorMetadataPrefill) -> tuple[set[str], set[str]]: + """ + Get requests that are done sending or recving. + """ + all_done_sending: set[str] = set() + all_done_recving: set[str] = set() + if self.rank == 0 and get_pp_group().is_last_rank: + # Update requests_finish_time with new finish times from metadata + with self._transfer_lock: + self.requests_finish_time.update( + {req_id: meta.finish_time for req_id, meta in metadata.requests.items()} + ) + current_time = time.monotonic() + # Identify requests whose finish time exceeds BLOCK_RELEASE_DELAY + out_date_reqs = [] + for req_id, finish_time in self.requests_finish_time.items(): + if current_time - finish_time > BLOCK_RELEASE_DELAY: + out_date_reqs.append(req_id) + else: + # Since the dict is ordered by finish_time, we can break early + break + for req_id in out_date_reqs: + logger.warning( + f"Request {req_id} is out of date, finish time: {self.requests_finish_time[req_id]}. Freeing blocks now." 
+ ) + all_done_sending.add(req_id) + del self.requests_finish_time[req_id] + + if len(self.receive_req_list) == 0: + return all_done_sending, all_done_recving + + with self._transfer_lock: + for item in self.receive_req_list: + req_id = item.get('remote_request_id')#item['remote_request_id'] + headers = item.get('trace_headers', {}) + logger.debug(f"Get_finished: request {req_id}") + all_done_sending.add(req_id) + # if the request's kv has been received, remove it from requests_finish_time + if req_id in self.requests_finish_time: + del self.requests_finish_time[req_id] + self.receive_req_list.clear() + + return all_done_sending, all_done_recving + + def get_pulled_kv_req_list(self): + path_p = f"tcp://{self.host_ip}:{self.trace_p_port}" + socket_p = self.ctx.socket(zmq.PUSH) + socket_p.connect(path_p) + while True: + try: + if self.input_socket.poll(timeout=10) > 0: + message = self.input_socket.recv_string() + id_list = json.loads(message) # Parse the received JSON string into a list + logger.debug("Received: %s", id_list) + with self._transfer_lock: + self.receive_req_list.extend(id_list) + if os.getenv("PROFILING_NAMELIST", None): + json_data = json.dumps(id_list) + socket_p.send_string(json_data) + except Exception as e: + logger.error("get pulled kv req list failed: %s", e) + + +class DecodeConnectorScheduler: + """Implementation of Scheduler side methods""" + def __init__(self, vllm_config: VllmConfig, trace_d_port: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + self.processed_request: set[str] = set() + self.ctx = zmq.Context() + self.zmq_socket_map = {} + + self.host_ip = get_local_ip() + self.trace_d_port = trace_d_port + self.recving_trace_headers: dict[str, dict] = {} + self._transfer_lock = threading.Lock() + self.ctx = zmq.Context() + self.pull_socket = self.ctx.socket(zmq.PULL) + self.pull_socket.bind(f"tcp://{self.host_ip}:{self.trace_d_port}") + + if os.getenv("PROFILING_NAMELIST", None): + self._listener_thread = threading.Thread(target=self._listen_worker_headers, daemon=True) + self._listener_thread.start() + + additional_config = vllm_config.additional_config + if additional_config: + self.async_pull_kv = additional_config.get("async_pull_kv", False) + else: + self.async_pull_kv = False + + if self.async_pull_kv: + self.context = zmq.Context() + self.pub = self.context.socket(zmq.PUB) + kv_rank = self.vllm_config.kv_transfer_config.kv_rank + self.pub.bind(f"ipc:///tmp/sched-pub-{kv_rank}-{vllm_config.parallel_config.data_parallel_rank_local}") + + def _listen_worker_headers(self): + while True: + try: + msg_str = self.pull_socket.recv_string() + msg_list = json.loads(msg_str) + for msg in msg_list: + req_id = msg['remote_request_id'] + headers = msg.get('trace_headers', {}) + with self._transfer_lock: + self.recving_trace_headers[req_id] = headers + except Exception as e: + logger.error(f"Failed to receive worker header (D): {e}") + time.sleep(1) + + def pop_trace_headers(self, req_id: str) -> dict: + with self._transfer_lock: + trace_headers = self.recving_trace_headers.pop(req_id, {}) + if trace_headers: + return trace_headers + + def _send_pulled_kv_req_list(self, path, data): + if path in self.zmq_socket_map: + socket = self.zmq_socket_map[path] + else: + socket = self.ctx.socket(zmq.PUSH) + socket.connect(path) + self.zmq_socket_map[path] = socket + logger.info(f"create new socket path:{path}") + + try: + json_data = json.dumps(data) + 
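# data is a list of per-request dicts (a request id plus its trace_headers); the prefill
# worker's get_pulled_kv_req_list() json-loads the same payload on the receiving end.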
socket.send_string(json_data) + logger.info(f"send string {json_data} path:{path}") + except Exception as e: + logger.error(f"Failed to send reqest_id {json_data} to prefill: {e}") + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + if request.request_id in self.processed_request: + return 0, False + params = request.kv_transfer_params + if params is None: + return 0, False + logger.debug( + "DatadistConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, params) + + if num_computed_tokens % self.block_size != 0: + raise RuntimeError("num_computed_tokens must be divisible by self.block_size") + rounded_num_prompt_tokens = self._round_up( + len(request.prompt_token_ids), self.block_size) + count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) + return count, count > 0 + + def _round_up(self, x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + logger.debug(f"Request id {request.request_id}: blocks length is {len(blocks.blocks)}") + params = request.kv_transfer_params + logger.debug( + "DatadistConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, params) + + self.processed_request.add(request.request_id) + if params is not None: + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_cluster_id", "remote_host_ip")): + self._reqs_need_recv[request.request_id] = ( + request, blocks.get_unhashed_block_ids()) + else: + logger.warning( + "Got invalid KVTransferParams: %s.", params) + + def build_connector_metadata( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + metadata = DatadistConnectorMetadata() + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + if req.kv_transfer_params is None: + logger.warning(f"For reuqest {req_id}: kv_transfer_params now is None") + else: + metadata.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + trace_headers=req.trace_headers or {}, + ) + req.kv_transfer_params = None + self._reqs_need_recv.clear() + + if self.async_pull_kv: + if scheduler_output is None: + # Let go fast path + if metadata.requests: + serialized_data = pickle.dumps(metadata) + self.pub.send(serialized_data) + + return metadata + + def request_finished( + self, + request: "Request", + block_ids: list[int], + spec_token_ids: Optional[list[int]] = [] + ) -> tuple[bool, Optional[dict[str, Any]]]: + if request.request_id in self.processed_request: + self.processed_request.remove(request.request_id) + if request.status == RequestStatus.FINISHED_ABORTED and request.kv_transfer_params is not None: + self._send_pulled_kv_req_list(request.kv_transfer_params.get("remote_host_ip"), [{'request_id': request.request_id, 'trace_headers': request.trace_headers or {}}]) + return False, None + + +class DecodeConnectorWorker: + """Worker implementation for datadist.""" + + def __init__(self, vllm_config: "VllmConfig", host_ip: str, host_cluster_id: int, trace_d_port: str): + self.vllm_config = vllm_config + self.host_cluster_id = host_cluster_id + self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.tp_rank = get_tensor_model_parallel_rank() + additional_config = vllm_config.additional_config + if additional_config: + self.async_pull_kv = 
additional_config.get("async_pull_kv", False) + self.multi_thread_pull_kv = additional_config.get("multi_thread_pull_kv", False) + self.multi_rank_pull_kv = additional_config.get("multi_rank_pull_kv", False) + else: + self.async_pull_kv = False + self.multi_thread_pull_kv = False + self.multi_rank_pull_kv = False + if self.multi_rank_pull_kv: + self.multi_thread_pull_kv = True + if vllm_config.parallel_config.tensor_parallel_size > 1 and self.multi_rank_pull_kv: + raise ValueError("multi_rank_pull_kv are not supported when tp > 1.") + + # check whether omni attention is enabled + manager_cls = LLMDataDistManager + if vllm_config.additional_config and "enable_omni_attn" in vllm_config.additional_config: + # do import only when necessary + from omni.accelerators.cache import OmniBiGroupDataDistManager, check_omni_attn_cmd_arg + use_omni_attn_mgr = check_omni_attn_cmd_arg(vllm_config.additional_config) + if use_omni_attn_mgr: + manager_cls = OmniBiGroupDataDistManager + logger.warning(f"DecodeConnector is using Omni datadist manager for KV transfer.") + self.datadist_manager = manager_cls(vllm_config, host_ip, 0) + + self._recving_transfers: list = [] + self._done_recving_count: defaultdict[str, int] = defaultdict(lambda: 0) + + self._pull_kv_lock = threading.Lock() + self.queues = {} # cluster_id -> queue.Queue + self.threads = {} # cluster_id -> threading.Thread + + self._transfer_lock = threading.Lock() + self.host_ip = host_ip + self.trace_d_port = trace_d_port + + self.ctx = zmq.Context() + self.zmq_socket_map = {} + + if self.async_pull_kv: + # dp_rank = vllm_config.parallel_config.data_parallel_rank_local + thread_name = f"async_pull_kv_{self.dp_rank}" + self.thread_on_fast_path_req = threading.Thread(target=self.on_fast_path_req, daemon=True, name=thread_name) + self.thread_on_fast_path_req.start() + logger.warning(f"DecodeConnectorWorker initialized with self.async_pull_kv enabled.") + + # Write thread name and native_id to file + dump_thread_to_file(self.thread_on_fast_path_req, thread_name, thread_dump_path) + + if self.multi_thread_pull_kv and self.vllm_config.parallel_config.tensor_parallel_size > 1: + self.tp_sync_path = f"ipc:///tmp/tp-sync-dp{self.vllm_config.parallel_config.data_parallel_rank}" + if get_tensor_model_parallel_rank() == 0: + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(self.tp_sync_path) + logger.info(f"ConnectWorker bind {self.tp_sync_path}") + + self.tp_sync_req_dict = {} + thread_name = f"decode_connector_sync_pulled_tp_kvcache_and_send_dp{self.vllm_config.parallel_config.data_parallel_rank}" + self.sync_thread = threading.Thread(target=self.sync_pulled_tp_kvcache_and_send, daemon=True, + name=thread_name) + self.sync_thread.start() + dump_thread_to_file(self.sync_thread, thread_name, thread_dump_path) + + def sync_pulled_tp_kvcache_and_send(self): + while True: + try: + if self.input_socket.poll(timeout=10) > 0: + data = self.input_socket.recv_json() + request_id = data[0].get("request_id") + remote_host_ip = data[0].get("remote_host_ip") + remote_request_id = data[0].get("remote_request_id", None) + trace_headers = data[0].get("trace_headers", {}) + # if request_id not in dict, set to 0, else do nothing + self.tp_sync_req_dict.setdefault(request_id, 0) + self.tp_sync_req_dict[request_id] += 1 + logger.debug(f"{request_id} finish pull kv {self.tp_sync_req_dict[request_id]} times.") + if self.tp_sync_req_dict[request_id] == self.vllm_config.parallel_config.tensor_parallel_size: + self.tp_sync_req_dict.pop(request_id) + 
self._send_pulled_kv_req_list(remote_host_ip, [{'remote_request_id': remote_request_id, 'trace_headers': trace_headers or {}}]) + with self._transfer_lock: + self._recving_transfers.append(request_id) + except Exception as e: + logger.error("Sync pulled kv when tp > 1 and send failed: %s", e) + + def on_fast_path_req(self): + context = zmq.Context() + sub = context.socket(zmq.SUB) + kv_rank = self.vllm_config.kv_transfer_config.kv_rank + sub.connect(f"ipc:///tmp/sched-pub-{kv_rank}-{self.vllm_config.parallel_config.data_parallel_rank_local}") + sub.setsockopt_string(zmq.SUBSCRIBE, "") + + while True: + serialized_data = sub.recv() + metadata = pickle.loads(serialized_data) + for req_id, meta in metadata.requests.items(): + if (len(meta.local_block_ids) > 0) and (len(meta.remote_block_ids) > 0): + self.start_load_kv(metadata) + if self.tp_rank == 0: + logger.info( + "Received fast path request for request %s with " + "local_block_ids: %s, remote_block_ids: %s.", + req_id, + len(meta.local_block_ids), + len(meta.remote_block_ids) + ) + + def worker(self, cluster_id): + q = self.queues[cluster_id] + time.sleep(0) + while True: + task = q.get() + if task is None: + continue + try: + self._read_blocks(**task) + except Exception as e: + logger.error("KV transfer task failed in thread %s: %s", cluster_id, e) + patch_data = [{'request_id': task['request_id'], 'trace_headers': task.get('trace_headers', {})}] + self._send_pulled_kv_req_list(task['remote_host_ip'], patch_data) + raise RuntimeError(f"Failed to pull kv for request:{task['request_id']} from cluster:{cluster_id}.") + q.task_done() + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + self.datadist_manager.register_memory(kv_caches) + # TODO:put multi-thread_pull_kv and multi_rank_pull_kv related registered_link_infos into queues + # In single thread pull kv mode, we use a single thread to pull kv + logger.info(" ***** Using single thread to pull kv.") + max_concurrents = 1 + self.executor = ThreadPoolExecutor(max_workers=max_concurrents) + + logger.debug("Finish register_kv_caches.") + + # Now go asynchronous pull_kv + def start_load_kv(self, metadata: DatadistConnectorMetadata): + logger.debug(f" ***** start_load_kv: {len(metadata.requests)}") + futures = [] + for req_id, meta in metadata.requests.items(): + # if the local_block_ids is empty, skip pulling kv for the request + if len(meta.local_block_ids) == 0: + if self.tp_rank == 0: + logger.info(f" ***** Request {req_id} has 0 local blocks, skip load kv.") + continue + # If local_block_ids is a flat list of int, omni-attention is not used + # and we can directly use the local_block_ids and remote_block_ids + if isinstance(meta.local_block_ids[0], int): + # local_block_ids (kv blocks in D) is more than remote_block_ids (kv blocks in P) + # leaded by lookahead num, which is used by eagle and multi step + if len(meta.remote_block_ids) < len(meta.local_block_ids): + meta.local_block_ids = meta.local_block_ids[:len(meta.remote_block_ids)] + logger.debug("look ahead token num is greater than 0") + # If remote_block_ids is more than local_block_ids, we only need the last N remote blocks + # where N is the number of local blocks + elif len(meta.remote_block_ids) > len(meta.local_block_ids): + meta.remote_block_ids = meta.remote_block_ids[-len(meta.local_block_ids):] + if self.tp_rank == 0: + logger.info( + " ***** start_load_kv for request %s " + "Num local_block_ids: %s. 
Num remote_block_ids: %s.", + req_id, + len(meta.local_block_ids), + len(meta.remote_block_ids) + ) + # If local_block_ids is a list of lists (e.g., [[], []]), omni-attention is used + # local_block_ids[0] is a list of local block ids for uncompressed layers + # local_block_ids[1] is a list of local block ids for compressed layers + elif isinstance(meta.local_block_ids[0], list): + # If local_block_ids[0] is a list of lists, we need to ensure that remote_block_ids + # is a list of lists as well, where each sublist corresponds to the local_block + meta.remote_block_ids = [meta.remote_block_ids] * len(meta.local_block_ids) + # If local_block_ids[0] is empty, skip pulling kv for the request + if len(meta.local_block_ids[0]) == 0: + if self.tp_rank == 0: + logger.info(f" ***** Request {req_id} has 0 local blocks, skip load kv.") + continue + # remote_block_ids in P is less than local_block_ids[0] in D, + # leaded by lookahead num, which is used by eagle and multi step + elif len(meta.remote_block_ids[0]) < len(meta.local_block_ids[0]): + meta.local_block_ids[0] = meta.local_block_ids[0][:len(meta.remote_block_ids[0])] + logger.debug("look ahead token num is greater than 0") + # If remote_block_ids in P is more than local_block_ids[0] in D, we only need the last N remote blocks + elif len(meta.remote_block_ids[0]) > len(meta.local_block_ids[0]): + meta.remote_block_ids[0] = meta.remote_block_ids[0][-len(meta.local_block_ids[0]):] + if self.tp_rank == 0: + logger.info( + " ***** start_load_kv for request %s " + "Num local_block_ids: %s. Num remote_block_ids: %s.", + req_id, + len(meta.local_block_ids[0]), + len(meta.remote_block_ids[0]) + ) + # handle the unexpected case where local_block_ids is not a list of int or list of lists + else: + logger.error(f"Unexpected type for meta.local_block_ids[0]: {type(meta.local_block_ids[0])}") + raise RuntimeError(f"Unexpected type for meta.local_block_ids[0]: {type(meta.local_block_ids[0])}") + cluster_ids = self.datadist_manager.get_real_remote_cluster_ids(meta) + if self.multi_rank_pull_kv: + # If multi_rank_pull_kv is enabled, each DP rank will pull kv from multiple P ranks + # and the cluster_ids are obtained from registered_link_infos + # If the local_block_ids is a flat list of int, we can directly use it + # As multi_rank_pull_kv is designed to pull kv from two P ranks, + # we split the local_block_ids and remote_block_ids into two parts + if not isinstance(meta.local_block_ids[0], list): + block_thre = len(meta.local_block_ids) // 2 + # If the local_block_ids is a flat list of list, only split the blocks for uncompressed layers + else: + block_thre = len(meta.local_block_ids[0]) // 2 + for idx_cluster, cluster_id in enumerate(cluster_ids): + if not isinstance(meta.local_block_ids[0], list): + if idx_cluster == 0: + local_blocks = meta.local_block_ids[:block_thre] + remote_blocks = meta.remote_block_ids[:block_thre] + len_local_blocks = len(local_blocks) + else: + local_blocks = meta.local_block_ids[block_thre:] + remote_blocks = meta.remote_block_ids[block_thre:] + len_local_blocks = len(local_blocks) + else: + if idx_cluster == 0: + # For uncompressed layers, split the local_block_ids[0] and remote_block_ids + # For compressed layers, only pull kv from the second P rank + local_blocks = [meta.local_block_ids[0][:block_thre], []] + # remote_blocks need to be split as well for getting kv blocks for compressed layers in P + remote_blocks = [meta.remote_block_ids[0][:block_thre], []] + len_local_blocks = len(local_blocks[0]) + else: + local_blocks = 
[meta.local_block_ids[0][block_thre:], meta.local_block_ids[1]] + remote_blocks = [meta.remote_block_ids[0][block_thre:], meta.remote_block_ids[1]] + len_local_blocks = len(local_blocks[0]) + if len_local_blocks > 0: + task = { + 'request_id': req_id, + 'remote_request_id': meta.remote_request_id, + 'dst_cluster_id': cluster_id, + 'local_block_ids': local_blocks, + 'remote_block_ids': remote_blocks, + 'remote_host_ip': meta.remote_host, + 'prefill_dp_rank': meta.remote_dp_rank, + 'trace_headers': meta.trace_headers or {}, + } + logger.warning(f"*********** dst cluster_id is {cluster_id}.") + self.queues[cluster_id].put(task) + elif self.multi_thread_pull_kv: + task = { + 'request_id': req_id, + 'remote_request_id': meta.remote_request_id, + 'dst_cluster_id': cluster_ids[0], + 'local_block_ids': meta.local_block_ids, + 'remote_block_ids': meta.remote_block_ids, + 'remote_host_ip': meta.remote_host, + 'prefill_dp_rank': meta.remote_dp_rank, + 'trace_headers': meta.trace_headers or {}, + } + + self.queues[cluster_ids[0]].put(task) + else: + # Use ThreadPoolExecutor to handle the task + future = self.executor.submit( + self._read_blocks, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + dst_cluster_id=cluster_ids[0], + request_id=req_id, + remote_request_id=meta.remote_request_id, + remote_host_ip=meta.remote_host, + prefill_dp_rank=meta.remote_dp_rank, + trace_headers=meta.trace_headers, + ) + futures.append(future) + + if not self.multi_thread_pull_kv: + for future in futures: + future.add_done_callback(handle_exception) + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + dst_cluster_id: str, + request_id: str, + remote_request_id: str, + remote_host_ip: str, + prefill_dp_rank: int, + trace_headers: Optional[Mapping[str, str]] = None + ): + start = time.time() + self.datadist_manager.pull_kv(remote_block_ids, local_block_ids, dst_cluster_id, prefill_dp_rank) + + if self.vllm_config.parallel_config.tensor_parallel_size == 1: + # tp=1, send to prefill tp rank0 directly. + self._send_pulled_kv_req_list(remote_host_ip, [{'remote_request_id': remote_request_id, 'trace_headers': trace_headers or {}}]) + with self._transfer_lock: + self._recving_transfers.append(request_id) + else: + if self.multi_thread_pull_kv: + # tp>1, send to decode to rank0 firstly. 
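# Rank 0 aggregates these per-rank notifications in sync_pulled_tp_kvcache_and_send()
# and forwards a single message to the prefill side once all TP ranks have reported.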
+ self._send_pulled_kv_req_list( + self.tp_sync_path, + { + "request_id": request_id, + "remote_request_id": remote_request_id, + "remote_host_ip": remote_host_ip, + 'trace_headers': trace_headers or {} + } + ) + else: + torch.distributed.barrier(group=get_tp_group().cpu_group) + if get_tensor_model_parallel_rank() == 0: + self._send_pulled_kv_req_list(remote_host_ip, [{'remote_request_id': remote_request_id, 'trace_headers': trace_headers or {}}]) + with self._transfer_lock: + self._recving_transfers.append(request_id) + logger.debug(f" ***** read block, req_id:{request_id}, local_block_ids:{local_block_ids}, remote_block_ids:{remote_block_ids}") + cost = time.time() - start + if self.tp_rank == 0: + logger.info(f" ***** read block, req_id:{request_id}, cost:{cost:.6f}") + + + def _send_pulled_kv_req_list(self, path, data): + if path in self.zmq_socket_map: + socket = self.zmq_socket_map[path] + else: + socket = self.ctx.socket(zmq.PUSH) + socket.connect(path) + self.zmq_socket_map[path] = socket + logger.info(f"create new socket path:{path}") + + path_d = f"tcp://{self.host_ip}:{self.trace_d_port}" + socket_d = self.ctx.socket(zmq.PUSH) + socket_d.connect(path_d) + + try: + json_data = json.dumps(data) + socket.send_string(json_data) + logger.info(f"send string {json_data} path:{path}") + if os.getenv("PROFILING_NAMELIST", None): + socket_d.send_string(json_data) + except Exception as e: + logger.error(f"Failed to send reqest_id {json_data} to prefill: {e}") + + def get_finished(self, metadata: DatadistConnectorMetadata) -> tuple[set[str], set[str]]: + # for decode size, done_sending is no need + all_done_sending: set[str] = set() + with self._transfer_lock: + all_done_recving = self._pop_done_transfers(self._recving_transfers) + if len(all_done_recving) > 0: + logger.debug( + "Get_finished: %s requests done recving", len(all_done_recving)) + + return all_done_sending, all_done_recving + + def _pop_done_transfers(self, transfers: list) -> set[str]: + done_req_ids: set[str] = set() + for req_id in transfers: + done_req_ids.add(req_id) + self._recving_transfers.clear() + return done_req_ids + +def handle_exception(future): + if future.exception(): + logger.error(f"Exception occurred in future: {future.exception()}") + raise future.exception() + +def dump_thread_to_file(thread, thread_name: str, folder_path: str): + + timeout = 5 # seconds + start_time = time.time() + while not hasattr(thread, "native_id"): + if time.time() - start_time > timeout: + logger.error(f"Timeout waiting for thread {thread_name} to have native_id.") + return + time.sleep(0.005) + + # Ensure the folder exists + if not os.path.exists(folder_path): + try: + os.makedirs(folder_path, exist_ok=True) + except Exception as e: + logger.error(f"Failed to create folder {folder_path}: {e}") + return + + file_path = os.path.join(folder_path, thread_name) + try: + with open(file_path, "w", encoding="utf-8") as f: + f.write(str(thread.native_id)) + except Exception as e: + logger.error(f"Failed to write thread info to {file_path}: {e}") diff --git a/omni/accelerators/pd/llmdatadist_manager_v1.py b/omni/accelerators/pd/llmdatadist_manager_v1.py new file mode 100644 index 000000000..a96cde71d --- /dev/null +++ b/omni/accelerators/pd/llmdatadist_manager_v1.py @@ -0,0 +1,468 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
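# Cluster ids in this module are plain integers packed by ip_port_to_int() near the end
# of the file as [ ip (32 bits) | port (16 bits) | tp_size (16 bits) ].
# Worked example (values are illustrative): for ip 10.0.0.1, port 15567 and tp_size 8,
#   ip_int     = 0x0A000001
#   cluster_id = (0x0A000001 << 32) | (15567 << 16) | 8  ==  0x0A0000013CCF0008
#   cluster_id_to_ip_port(cluster_id)  ->  ("10.0.0.1:15567", 8, 0)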
+ +import json +import time +from collections import defaultdict, namedtuple +from functools import cached_property + +import llm_datadist +import torch +from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, + LLMDataDist, LLMRole, RegisterMemStatus, LLMException, LLMStatusCode, + Placement, LLMClusterInfo, DataType) + +from vllm.config import KVTransferConfig +from vllm.distributed import get_world_group +from vllm.logger import init_logger +from vllm.model_executor.models.utils import extract_layer_index +from omni.accelerators.pd.ranktable.local_info import LocalInfo +from omni.accelerators.pd.ranktable.rank_table import GlobalRankTable, RankTableConfig +from omni.accelerators.pd.utils import get_p_start_rank, prepare_ranktables +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import (get_tp_group, get_dp_group, get_world_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +import os +import socket +import struct + +logger = init_logger(__name__) + +_ROLE_STR_TO_ENUM = { + "kv_producer": LLMRole.PROMPT, + "kv_consumer": LLMRole.DECODER +} + +TORCH_DTYPE_TO_NPU_DTYPE = { + torch.half: llm_datadist.DataType.DT_FLOAT16, + torch.float16: llm_datadist.DataType.DT_FLOAT16, + torch.bfloat16: llm_datadist.DataType.DT_BF16, + torch.float: llm_datadist.DataType.DT_FLOAT, + torch.float32: llm_datadist.DataType.DT_FLOAT, + torch.int8: llm_datadist.DataType.DT_INT8, + torch.int64: llm_datadist.DataType.DT_INT64, + torch.int32: llm_datadist.DataType.DT_INT32 +} + +SCHEDULER_LINK_BATCH_SIZE = 32 +SCHEDULER_LINK_INTERVAL = 0.5 +KV_CACHE_RETRY_TIMES = 1 +KV_CACHE_RETRY_WAIT_SECOND = 1 +SYNC_KV_TIMEOUT = 5000 # ms +LINK_TIMEOUT = 5000 # ms + +RETRYABLE_CODES = [ + LLMStatusCode.LLM_REPEAT_REQUEST, + LLMStatusCode.LLM_CLUSTER_NUM_EXCEED_LIMIT, + LLMStatusCode.LLM_PROCESSING_LINK, # Building chain is in progress + LLMStatusCode.LLM_DEVICE_OUT_OF_MEMORY, + LLMStatusCode.LLM_TIMEOUT, + LLMStatusCode.LLM_WAIT_PROCESS_TIMEOUT, + LLMStatusCode.LLM_LINK_BUSY, +] + +NUM_DIE_PER_MACH = int(os.getenv("NUM_DIE_PER_MACH", "16")) + +class LLMDataDistConfig: + """ + Configuration for the separate deployment. 
+ """ + def __init__(self, vllm_config: VllmConfig, local_host_ip, host_port, ignore_load_rank=False) -> None: + additional_config = vllm_config.additional_config + if additional_config: + self.multi_rank_pull_kv = additional_config.get("multi_rank_pull_kv", False) + else: + self.multi_rank_pull_kv = False + self.local_host_ip = local_host_ip + self.host_port = host_port + self.kv_transfer_config = vllm_config.kv_transfer_config + self.kv_role_tmp = self.kv_transfer_config.kv_role + + self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.tp_rank = 0 + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.dp_size = vllm_config.parallel_config.data_parallel_size + + if ignore_load_rank: + self.rank = -1 + self.local_rank = -1 + self.cluster_id = -1 + else: + self.rank = get_world_group().rank_in_group + self.local_rank = get_world_group().local_rank + self.cluster_id = ip_port_to_int(f"{self.local_host_ip}:{int(self.host_port)+self.local_rank}", self.tp_size) + + # will be used in d side to checkout which P rank is selected to build kv link + self.kv_parallel_size = self.kv_transfer_config.kv_parallel_size + self.kv_producer_dp_size = self.kv_transfer_config.kv_connector_extra_config.get("kv_producer_dp_size", 1) + + host_ip_list = self._get_worker_ips() + self.host_ip_list = host_ip_list + + timestamp_ms = round(time.monotonic() * 1_000) + # host_cluster_id is a list, in order to handle the case that multi-node for one TP group + ip_integers = [ + ip_port_to_int(f"{ip}:{host_port}", self.tp_size) + for ip in host_ip_list + ] + + # (timestamp_ms, ip1_int, ip2_int, ip3_int, ...) + self.host_cluster_id = (timestamp_ms, *ip_integers) + + # get all node ips in a TP group + def _get_worker_ips(self): + """Return worker IPs. Only query Ray when Ray is actually available/running. + + Behavior: + - If self.is_prefill is False: return [self.local_host_ip]. + - If Ray is not installed: log and return [self.local_host_ip]. + - If Ray is installed but no cluster is reachable: log and return [self.local_host_ip]. + - If a Ray cluster is reachable: return all Alive nodes' NodeManagerAddress, + with head node (if detected) placed first. + """ + # default fallback + worker_ips = [self.local_host_ip] + + if not self.is_prefill: + return worker_ips + + try: + import ray + except ImportError: + logger.debug("Ray is not installed; skipping Ray cluster discovery.") + return worker_ips + + try: + if not ray.is_initialized(): + ray.init(address="auto", ignore_reinit_error=True) + nodes = ray.nodes() + except Exception as e: + logger.warning(f"Failed to connect/list Ray nodes (address='auto'): {e}. Using local_host_ip.") + return worker_ips + + ips = [] + head_ip = None + + for node in nodes: + if node.get("Alive"): + addr = node.get("NodeManagerAddress") + if addr: + ips.append(addr) + gcs_addr = node.get("GcsAddress", "") + if addr in gcs_addr: + head_ip = addr + else: + logger.error("Detected dead node in the Ray cluster. 
Please check machines' health.") + + if not ips: + return worker_ips + + if head_ip and head_ip in ips: + ips.remove(head_ip) + worker_ips = [head_ip] + ips + else: + worker_ips = ips + + return worker_ips + + @cached_property + def role(self): + return _ROLE_STR_TO_ENUM[self.kv_transfer_config.kv_role] + + @cached_property + def is_prefill(self): + return self.role == LLMRole.PROMPT + + +class LLMDataDistManager: + def __init__(self, vllm_config: VllmConfig, local_host_ip, host_port): + additional_config = vllm_config.additional_config + if additional_config: # pragma: no cover + self.multi_rank_pull_kv = additional_config.get("multi_rank_pull_kv", False) + else: # pragma: no cover + self.multi_rank_pull_kv = False + self.kv_transfer_config = vllm_config.kv_transfer_config + self.data_dist_config = LLMDataDistConfig(vllm_config, local_host_ip, host_port) + self.rank = self.data_dist_config.rank + self.local_rank = self.data_dist_config.local_rank + self.tp_size = self.data_dist_config.tp_size + self.tp_rank = self.data_dist_config.tp_rank + self.dp_size = self.data_dist_config.dp_size + self.dp_rank = self.data_dist_config.dp_rank + self.prefill_dp_size = self.data_dist_config.kv_producer_dp_size + if not self.data_dist_config.is_prefill: + self.decode_id = self.dp_rank // NUM_DIE_PER_MACH + + self.data_dist_engine = self._init_llm_data_dist() + + self.registered_kv_caches = [] + self.rank_link_info_map = {} + # the look-up table for pull kv, managed by each dp process + # { key: (host_cluster_id, prefill_dp_rank, d_rank), value:[prompt_cluster_id_list] } + self.registered_link_infos = {} + + def get_real_remote_cluster_ids(self, meta, tp_rank=0): + # remote_cluster_id: (timestamp, ip1, ip2, ...) + remote_id_key = tuple(meta.remote_cluster_id) if isinstance(meta.remote_cluster_id, list) else meta.remote_cluster_id + + key = (remote_id_key, meta.remote_dp_rank, self.rank) + remote_cluster_ids = self.registered_link_infos.get(key, None) + + if remote_cluster_ids is None: + old_key = None + for (reg_key, reg_dp_rank, reg_rank) in list(self.registered_link_infos.keys()): + if (reg_dp_rank == meta.remote_dp_rank and reg_rank == self.rank and + any(ip in reg_key[1:] for ip in remote_id_key[1:])): + old_key = (reg_key, reg_dp_rank, reg_rank) # reg_key: (time_stamp, ip1_int, .., ip2_int) + break + if old_key: + self.close_link(old_key[0], meta.remote_dp_rank, self.rank, tp_rank) + logger.warning(f"Deleted old link with {old_key}") + logger.warning(f"Could not find remote cluster id from {meta.remote_cluster_id=}, {meta.remote_dp_rank=}.") + logger.warning(f"Try to build new link with {meta.remote_cluster_id=}, {meta.remote_dp_rank=}...") + # Ensure register_link also receives hashable data + self.register_link(remote_id_key, meta.remote_dp_rank, self.rank, tp_rank) + remote_cluster_ids = self.registered_link_infos.get(key, None) + + return remote_cluster_ids + + def _init_llm_data_dist(self): + data_dist = LLMDataDist(self.data_dist_config.role, self.data_dist_config.cluster_id) + llm_config = LLMConfig() + llm_config.device_id = self.local_rank + llm_config.local_comm_res = "" + # RoCE timeout is SYNC_KV_TIMEOUT ms, prevent pull kv timeout + llm_config.sync_kv_timeout = SYNC_KV_TIMEOUT + llm_config.enable_remote_cache_accessible = True + + # do new_datadist_link + llm_config.local_comm_res = "" + # if is prefill, need to listen on specific ip and port to accept decode side connection + if self.data_dist_config.is_prefill: + host_ip_t = self.data_dist_config.local_host_ip + host_port_t = 
int(self.data_dist_config.host_port) + int(self.data_dist_config.local_rank) + llm_config.listen_ip_info = f"{host_ip_t}:{host_port_t}" + + options = llm_config.generate_options() + data_dist.init(options) + logger.info(f"init {self.data_dist_config.kv_role_tmp} success, {self.data_dist_config.cluster_id=}") + + return data_dist + + # dynamically register link only when is needed + def register_link(self, host_cluster_id, prefill_dp_rank, d_rank, tp_rank=0): + prompt_cluster_id_list = self._get_cluster_id_list(host_cluster_id[1:], prefill_dp_rank, d_rank, tp_rank) + clusters = [] + for PROMPT_CLUSTER_ID in prompt_cluster_id_list: + cluster = LLMClusterInfo() + host_ip, tp_size, tp_rank = cluster_id_to_ip_port(PROMPT_CLUSTER_ID) + remote_host_ip, port = host_ip.split(':') + cluster.remote_cluster_id = PROMPT_CLUSTER_ID + cluster.append_local_ip_info(self._get_local_ip(), 0) + cluster.append_remote_ip_info(remote_host_ip, int(port)) + clusters.append(cluster) + ret, _ = self.data_dist_engine.link_clusters(clusters, timeout=LINK_TIMEOUT) + if ret != LLMStatusCode.LLM_SUCCESS: + raise Exception("link failed") + # add the cluster_id to the dict + if not self.data_dist_config.is_prefill: + self.registered_link_infos[(host_cluster_id, prefill_dp_rank, d_rank)] = prompt_cluster_id_list + logger.info(f"rank:{self.rank} linked to : {remote_host_ip}, {prompt_cluster_id_list=}") + + # close the link when it is confirmed to be broken + def close_link(self, host_cluster_id, prefill_dp_rank, d_rank, tp_rank=0): + if not self.data_dist_config.is_prefill: + prompt_cluster_id_list = self._get_cluster_id_list(host_cluster_id[1:], prefill_dp_rank, d_rank, tp_rank) + else: + prompt_cluster_id_list = [host_cluster_id] + clusters = [] + for PROMPT_CLUSTER_ID in prompt_cluster_id_list: + cluster = LLMClusterInfo() + host_ip, tp_size, tp_rank = cluster_id_to_ip_port(PROMPT_CLUSTER_ID) + remote_host_ip, port = host_ip.split(':') + cluster.remote_cluster_id = PROMPT_CLUSTER_ID + cluster.append_local_ip_info(self._get_local_ip(), 0) + cluster.append_remote_ip_info(remote_host_ip, int(port)) + clusters.append(cluster) + ret, _ = self.data_dist_engine.unlink_clusters(clusters, timeout=LINK_TIMEOUT, force=True) + if ret != LLMStatusCode.LLM_SUCCESS: + raise Exception("unlink failed") + # remove the cluster_id from the dict + if not self.data_dist_config.is_prefill: + self.registered_link_infos.pop((host_cluster_id, prefill_dp_rank, d_rank), None) + logger.info(f"rank:{self.rank} unlinked with : {remote_host_ip}, {prompt_cluster_id_list=}") + + def _pull_blocks(self, src_cache_key, dst_cache, src_blocks, dst_blocks): + """" pull kv from remote cache to local cache, support return error state if pull kv fails """ + pass + + def pull_kv(self, src_blocks, tgt_blocks, prompt_cluster_id, prefill_dp_rank): + """ pull kv from remote cache to local cache, support to refresh link when pull kv fails """ + pass + + def _refresh_link(self, prompt_cluster_id, prefill_dp_rank, d_rank): + """ refresh the kv link: unlink + link """ + pass + + # search for the host_cluster_id in key using the prompt_cluster_id in value + def _get_host_cluster_id(self, prompt_cluster_id, prefill_dp_rank, d_rank): + """ search for the host_cluster_id in key using the prompt_cluster_id in value """ + pass + + def _get_cluster_id_list(self, host_cluster_ids, prefill_dp_rank, d_rank, tp_rank): + """ compute the cluster id that should be linked with the target dp rank """ + pass + + def _get_local_ip(self): + s = socket.socket(socket.AF_INET, 
socket.SOCK_DGRAM) + try: + s.connect(('8.8.8.8', 80)) + ip = s.getsockname()[0] + finally: + s.close() + return ip + + # reuse the existing code + def register_memory(self, kv_caches: dict[str, torch.Tensor]): + if len(self.registered_kv_caches) > 0: + raise ValueError("Attr `registered_kv_caches` must be empty before register kv_caches.") + if isinstance(kv_caches, dict): + flatten_kv_caches = unzip_kv_cache_dict(kv_caches) + else: + flatten_kv_caches = unzip_kv_cache_list(kv_caches) + + # dense model. + flatten_kv_caches = maybe_merge_kv_caches(flatten_kv_caches) + # spec model. + flatten_kv_caches = maybe_split_kv_caches_for_spec_layers(flatten_kv_caches) + + for model_id, sub_kv_caches in enumerate(flatten_kv_caches): + cache_desc = CacheDesc(num_tensors=len(sub_kv_caches), shape=tuple(sub_kv_caches[0].shape), + data_type=TORCH_DTYPE_TO_NPU_DTYPE[sub_kv_caches[0].dtype]) + + cache_addrs = [int(item.data_ptr()) for item in sub_kv_caches] + + if self.data_dist_config.is_prefill: + cache_key = BlocksCacheKey(self.data_dist_engine.cluster_id, model_id=model_id) + else: + cache_key = None + + cache = self.data_dist_engine.cache_manager.register_blocks_cache(cache_desc, cache_addrs, cache_key) + self.registered_kv_caches.append(cache) + logger.debug(f" ***** registered_kv_caches num:{len(self.registered_kv_caches)}") + +# reuse the existing code +def unzip_kv_cache_dict(kv_caches: dict[str, torch.Tensor], ): + # Convert kv_caches dict to a list of tensors in the order of layer_index. + _, first_kv_cache = next(iter(kv_caches.items())) + if isinstance(first_kv_cache, tuple): + cache_num = len(first_kv_cache) + else: + cache_num = 1 + + flatten_kv_caches = [[] for _ in range(cache_num)] + + index2name = defaultdict(list) + for layer_name in kv_caches: + index2name[extract_layer_index(layer_name)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + if len(layer_names) > 1: + # One typical case is encoder-decoder model, e.g., bart. + # The cross attention and self attention in the same decoder layer + # has different layer_name but the same layer_index. 
+ raise NotImplementedError + layer_name = layer_names[0] + kv_cache = kv_caches[layer_name] + if isinstance(kv_cache, tuple): + for index, sub_cache in enumerate(kv_cache): + flatten_kv_caches[index].append(sub_cache) + else: + flatten_kv_caches[0].append(kv_cache) + return flatten_kv_caches + +# reuse the existing code +def unzip_kv_cache_list(kv_caches: list[torch.Tensor], ): + first_kv_cache = kv_caches[0] + if isinstance(first_kv_cache, tuple): + cache_num = len(first_kv_cache) + else: + cache_num = 1 + + flatten_kv_caches = [[] for _ in range(cache_num)] + + for kv_cache in kv_caches: + if isinstance(kv_cache, tuple): + for index, sub_cache in enumerate(kv_cache): + flatten_kv_caches[index].append(sub_cache) + else: + flatten_kv_caches[0].append(kv_cache) + return flatten_kv_caches + +# reuse the existing code +def maybe_merge_kv_caches(flatten_kv_caches): + # only 1 kvcache tensor with shape (2, b, s, n, d) + if len(flatten_kv_caches) == 1 and len(flatten_kv_caches[0][0].shape) == 5 and flatten_kv_caches[0][0].shape[0] == 2: + merged_kv_caches = [[]] + for sub_kv_caches in flatten_kv_caches[0]: + merged_kv_caches[0].append(sub_kv_caches[0]) + merged_kv_caches[1].append(sub_kv_caches[1]) + return merged_kv_caches + return flatten_kv_caches + +# reuse the existing code +def maybe_split_kv_caches_for_spec_layers(flatten_kv_caches): + flatten_kv_caches_split = [] + need_split = False + for caches in flatten_kv_caches: + shape_dict = {} + for cache in caches: + if str(cache.shape) not in shape_dict: + shape_dict[str(cache.shape)] = [] + shape_dict[str(cache.shape)].append(cache) + + flatten_kv_caches_split.extend(shape_dict.values()) + if len(shape_dict) > 1 or need_split: + need_split = True + + if not need_split: + return flatten_kv_caches + else: + return flatten_kv_caches_split + +def ip_port_to_int(ip_port, tp_size, tp_rank=0): + """ convert ip_port to int64 cluster id + + layout: + [ ip (32 bits) | port (16 bits) | tp_size (16 bits) ] + """ + ip, port_str = ip_port.split(':') + port = int(port_str) + if not (0 <= port <= 65535): + raise ValueError(" port must be in 0-65535 ") + # convert IP to 4 byte boolean + ip_bytes = socket.inet_aton(ip) + # convert 4 byte IP to 32 bit int + ip_int = struct.unpack('!I', ip_bytes)[0] + # now we only contain ip, port, tp_size, tp_rank is ignored for simplification + # result = (ip_int << 48) | (port << 32) | (tp_size << 16) | (tp_rank) + result = (ip_int << 32) | (port << 16) | (tp_size & 0xFFFF) + return result + + + +def cluster_id_to_ip_port(cluster_id): + """Extract ip_port from int64 cluster id (inverse of ip_port_to_int).""" + if not isinstance(cluster_id, int): + raise TypeError("cluster_id must be int type") + + # Extract fields (reverse order of packing) + tp_size = cluster_id & 0xFFFF # Lower 16 bits + port = (cluster_id >> 16) & 0xFFFF # Next 16 bits + ip_int = (cluster_id >> 32) & 0xFFFFFFFF # Upper 32 bits + + ip = socket.inet_ntoa(struct.pack('!I', ip_int)) + + return f"{ip}:{port}", tp_size, 0 # tp_rank always 0 \ No newline at end of file -- Gitee From 923818bc220c289f8f6c65ea70dafc8f6dfe95a6 Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 25 Nov 2025 10:59:17 +0800 Subject: [PATCH 02/15] fix a typo --- omni/accelerators/cache/pd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omni/accelerators/cache/pd.py b/omni/accelerators/cache/pd.py index ebbcdc612..aba35c75c 100644 --- a/omni/accelerators/cache/pd.py +++ b/omni/accelerators/cache/pd.py @@ -2,7 +2,7 @@ from typing_extensions import override import 
torch from llm_datadist.v2.llm_types import Cache, CacheDesc, BlocksCacheKey import os -USE_NEW_DATADIST = os.getenv("ENABLE_DYNAMIC_DATADIST", "0") == "1" +USE_NEW_DATADIST = os.getenv("ENABLE_DYNAMIC_LLMDATADIST", "0") == "1" if USE_NEW_DATADIST: from omni.accelerators.pd.llmdatadist_manager_v1 import ( LLMDataDistManager, -- Gitee From c0e87eb2a12caeaaeba25b8048fa924151333ccd Mon Sep 17 00:00:00 2001 From: d00954433 Date: Tue, 25 Nov 2025 11:15:49 +0800 Subject: [PATCH 03/15] add the partial function implementation of llmdatadist_manager_v1 --- .../accelerators/pd/llmdatadist_manager_v1.py | 91 +++++++++++++++++-- 1 file changed, 85 insertions(+), 6 deletions(-) diff --git a/omni/accelerators/pd/llmdatadist_manager_v1.py b/omni/accelerators/pd/llmdatadist_manager_v1.py index a96cde71d..a987dc1a3 100644 --- a/omni/accelerators/pd/llmdatadist_manager_v1.py +++ b/omni/accelerators/pd/llmdatadist_manager_v1.py @@ -295,24 +295,103 @@ class LLMDataDistManager: def _pull_blocks(self, src_cache_key, dst_cache, src_blocks, dst_blocks): """" pull kv from remote cache to local cache, support return error state if pull kv fails """ - pass + for attempt in range(KV_CACHE_RETRY_TIMES): + try: + self.data_dist_engine.cache_manager.pull_blocks( + src_cache_key, dst_cache, src_blocks, dst_blocks + ) + return True + except LLMException as e: + code = e.status_code + if code in RETRYABLE_CODES: + logger.info( + f"kv cache pull blocks failed, need retry" + f"(attempt {attempt + 1}/{KV_CACHE_RETRY_TIMES}): {e}" + ) + if attempt < KV_CACHE_RETRY_TIMES - 1: + time.sleep(KV_CACHE_RETRY_WAIT_SECOND) + continue + logger.error( + f"kv cache pull blocks failed after {KV_CACHE_RETRY_TIMES} attempts: {e}" + ) + return False + else: + logger.error(f"kv cache pull blocks failed (non-retryable): {e}") + return False + except (TypeError, ValueError) as e: + logger.error(f"kv cache pull blocks input error: {e}") + return False + logger.error("kv cache pull blocks exhausted attempts without success") + return False def pull_kv(self, src_blocks, tgt_blocks, prompt_cluster_id, prefill_dp_rank): """ pull kv from remote cache to local cache, support to refresh link when pull kv fails """ - pass + torch.npu.set_device(f"npu:{self.local_rank}") + for model_id, kv_cache in enumerate(self.registered_kv_caches): + prompt_cache_key = BlocksCacheKey( + prompt_cluster_id=prompt_cluster_id, model_id=model_id) + ret = self._pull_blocks(prompt_cache_key, kv_cache, + src_blocks, tgt_blocks) + if not ret: + logger.warning(f"======= failed pull kv with {prompt_cluster_id=} ========") + self._refresh_link(prompt_cluster_id, prefill_dp_rank, self.rank) + logger.warning(f"======= successfully rebuild kv link with {prompt_cluster_id=} ========") + ret_updated = self._pull_blocks(prompt_cache_key, kv_cache, + src_blocks, tgt_blocks) + if not ret_updated: + raise RuntimeError(f"Failed to pull kv even if rebuild the kv link!") def _refresh_link(self, prompt_cluster_id, prefill_dp_rank, d_rank): """ refresh the kv link: unlink + link """ - pass + logger.warning(f"======= refresh_link with {prompt_cluster_id=} ========") + (host_cluster_id, prefill_dp_rank, d_rank) = \ + self._get_host_cluster_id(prompt_cluster_id, prefill_dp_rank, d_rank) + if host_cluster_id is not None: + self.close_link(host_cluster_id, prefill_dp_rank, d_rank) + logger.warning(f"======= rebuild_link with {prompt_cluster_id=} ========") + self.register_link(host_cluster_id, prefill_dp_rank, d_rank) + else: + raise RuntimeError(f"Unregistered host cluster id!!!") # search for 
the host_cluster_id in key using the prompt_cluster_id in value def _get_host_cluster_id(self, prompt_cluster_id, prefill_dp_rank, d_rank): """ search for the host_cluster_id in key using the prompt_cluster_id in value """ - pass - + prompt_p_metas = [ + key for key, values in self.registered_link_infos.items() + if (isinstance(values, list) and + prompt_cluster_id in values and + len(key) >= 3 and + key[1] == prefill_dp_rank and + key[2] == d_rank) + ] + if not prompt_p_metas: + return None + else: + return prompt_p_metas[0] + def _get_cluster_id_list(self, host_cluster_ids, prefill_dp_rank, d_rank, tp_rank): """ compute the cluster id that should be linked with the target dp rank """ - pass + if isinstance(host_cluster_ids, int): + host_cluster_ids = [host_cluster_ids] + ip_ports = [] + for host_cluster_id in host_cluster_ids: + ip_port, prefill_tp_size, _ = cluster_id_to_ip_port(host_cluster_id) + ip_ports.append(ip_port) + decode_tp_size = self.data_dist_config.kv_parallel_size + decode_id = 0 + decode_num = int(os.getenv('DECODE_POD_NUM', "1")) + + p_rank_start = get_p_start_rank(prefill_tp_size, 1, decode_tp_size, self.dp_size, + decode_num, decode_id, d_rank) + p_rank_list = [p_rank_start + dp_idx * prefill_tp_size for dp_idx in range(self.prefill_dp_size)] + cluster_id_list = [] + for p_rank in p_rank_list: + ip_port = ip_ports[p_rank // NUM_DIE_PER_MACH] + ip, port_str = ip_port.split(':') + port = int(port_str) + (p_rank % NUM_DIE_PER_MACH) + cluster_id = ip_port_to_int(f"{ip}:{port}", prefill_tp_size) + cluster_id_list.append(cluster_id) + return cluster_id_list def _get_local_ip(self): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) -- Gitee From 21dc7a6fe8e1b03a70e8cc13ed9c7dcc1cb2e92b Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 25 Nov 2025 05:49:21 +0000 Subject: [PATCH 04/15] update omni/accelerators/pd/__init__.py. Signed-off-by: Yao Yunxiang --- omni/accelerators/pd/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omni/accelerators/pd/__init__.py b/omni/accelerators/pd/__init__.py index 3d96cc695..110356f97 100644 --- a/omni/accelerators/pd/__init__.py +++ b/omni/accelerators/pd/__init__.py @@ -11,7 +11,7 @@ def register(): KVConnectorFactory.register_connector( "AscendHcclConnectorV1", "omni.accelerators.pd.omni_cache_connector_v1" if os.getenv("ENABLE_OMNI_CACHE", "0") == "1" - elif "omni.accelerators.pd.llmdatadist_connector_v2" if os.getenv("ENABLE_DYNAMIC_LLMDATADIST", "0") == "1" + else "omni.accelerators.pd.llmdatadist_connector_v2" if os.getenv("ENABLE_DYNAMIC_LLMDATADIST", "0") == "1" else "omni.accelerators.pd.llmdatadist_connector_v1", "LLMDataDistConnector" ) -- Gitee From b998aad9c4bc8344731baacefef7b477f39c2aba Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 25 Nov 2025 06:55:51 +0000 Subject: [PATCH 05/15] update omni/accelerators/pd/__init__.py. 
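A note on the registration fix above: a Python conditional expression has no elif clause, so alternatives chain as nested if/else expressions, which is what the corrected selection does. A minimal sketch of that logic, with a hypothetical helper name and the module paths used in the patch:

import os

def _select_connector_module() -> str:
    # Each else introduces the next condition, so the chain reads top to bottom.
    return (
        "omni.accelerators.pd.omni_cache_connector_v1"
        if os.getenv("ENABLE_OMNI_CACHE", "0") == "1"
        else "omni.accelerators.pd.llmdatadist_connector_v2"
        if os.getenv("ENABLE_DYNAMIC_LLMDATADIST", "0") == "1"
        else "omni.accelerators.pd.llmdatadist_connector_v1"
    )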
Signed-off-by: Yao Yunxiang --- omni/accelerators/pd/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/omni/accelerators/pd/__init__.py b/omni/accelerators/pd/__init__.py index 66691632d..d7fd5d072 100644 --- a/omni/accelerators/pd/__init__.py +++ b/omni/accelerators/pd/__init__.py @@ -26,8 +26,9 @@ def register(): "EmsConnector", "omni.accelerators.pd.ems_connector", "EmsConnector" + ) - KVConnectorFactory.register_connector( + KVConnectorFactory.register_connector( "SwapKVConnector", "omni.accelerators.pd.swap_kv_connector", "SwapKVConnector" -- Gitee From 6cc7e5bd492180fe1b2567f0f0c8b41259ba554b Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 2 Dec 2025 18:35:42 +0800 Subject: [PATCH 06/15] merge new and old datadist_connectors into one file --- omni/accelerators/pd/__init__.py | 1 - .../pd/llmdatadist_connector_v1.py | 74 +- .../pd/llmdatadist_connector_v2.py | 967 ------------------ 3 files changed, 51 insertions(+), 991 deletions(-) delete mode 100644 omni/accelerators/pd/llmdatadist_connector_v2.py diff --git a/omni/accelerators/pd/__init__.py b/omni/accelerators/pd/__init__.py index d7fd5d072..a25d3169a 100644 --- a/omni/accelerators/pd/__init__.py +++ b/omni/accelerators/pd/__init__.py @@ -11,7 +11,6 @@ def register(): KVConnectorFactory.register_connector( "AscendHcclConnectorV1", "omni.accelerators.pd.omni_cache_connector_v1" if os.getenv("ENABLE_OMNI_CACHE", "0") == "1" - else "omni.accelerators.pd.llmdatadist_connector_v2" if os.getenv("ENABLE_DYNAMIC_LLMDATADIST", "0") == "1" else "omni.accelerators.pd.llmdatadist_connector_v1", "LLMDataDistConnector" ) diff --git a/omni/accelerators/pd/llmdatadist_connector_v1.py b/omni/accelerators/pd/llmdatadist_connector_v1.py index d30045c36..1781c7b32 100644 --- a/omni/accelerators/pd/llmdatadist_connector_v1.py +++ b/omni/accelerators/pd/llmdatadist_connector_v1.py @@ -13,7 +13,7 @@ import threading import time from typing import TYPE_CHECKING, Any, Optional -import zmq +import socket from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) @@ -50,7 +50,13 @@ GET_META_MSG = b"get_meta_msg" thread_dump_path = os.environ.get("VLLM_THREAD_DUMP_PATH", "/tmp/vllm_thread_info") BLOCK_RELEASE_DELAY = int(os.environ.get("BLOCK_RELEASE_DELAY", 600)) # seconds, use to free blocks when the request is finished for a long time -from omni.accelerators.pd.llmdatadist_manager import LLMDataDistManager, LLMDataDistConfig +if os.getenv("ENABLE_DYNAMIC_LLMDATADIST", "0") == "1": + FLAG_ENABLE_DYNAMIC_LLMDATADIST = True + LLMDATADIST_BASE_PORT = int(os.environ.get("VLLM_LLMDATADIST_BASE_PORT", 15567)) + from omni.accelerators.pd.llmdatadist_manager_v1 import LLMDataDistManager, LLMDataDistConfig +else: + FLAG_ENABLE_DYNAMIC_LLMDATADIST = False + from omni.accelerators.pd.llmdatadist_manager import LLMDataDistManager, LLMDataDistConfig from omni.tools.profiler.apply_profiler_patches import patch_request patch_request() @@ -119,9 +125,16 @@ class LLMDataDistConnector(KVConnectorBase_V1): vllm_config.kv_transfer_config.kv_parallel_size = 1 logger.info("Set kv_parallel_size to 1 when use deepseek mla model.") - self.datadist_config = LLMDataDistConfig(vllm_config, ignore_load_rank=True) - self.cluster_id_start = self.datadist_config.cluster_id_start - self.host_ip = self.datadist_config.local_group.host_ip + if ENABLE_DYNAMIC_LLMDATADIST: + local_host_ip = get_local_ip() + local_host_port = LLMDATADIST_BASE_PORT + 
self.datadist_config = LLMDataDistConfig(vllm_config, local_host_ip, local_host_port, ignore_load_rank=True) + self.host_cluster_id = self.datadist_config.host_cluster_id + self.host_ip = local_host_ip + else: + self.datadist_config = LLMDataDistConfig(vllm_config, ignore_load_rank=True) + self.host_cluster_id = self.datadist_config.cluster_id_start + self.host_ip = self.datadist_config.local_group.host_ip # Introduce the environment variable VLLM_LLMDATADIST_ZMQ_PORT to resolve ZMQ connection conflicts during # multi-P deployments on the same machine. # This variable should not be set separately unless specifically required for this scenario. @@ -133,7 +146,7 @@ class LLMDataDistConnector(KVConnectorBase_V1): if role == KVConnectorRole.SCHEDULER: if self.is_prefill: - self.connector_scheduler = PrefillConnectorScheduler(vllm_config, self.cluster_id_start, self.host_ip, str(self.host_port), str(self.host_port + 1000)) + self.connector_scheduler = PrefillConnectorScheduler(vllm_config, self.host_cluster_id, self.host_ip, str(self.host_port), str(self.host_port + 1000)) else: self.connector_scheduler = DecodeConnectorScheduler(vllm_config, str(self.host_port + 2000)) self.connector_worker = None @@ -141,7 +154,7 @@ class LLMDataDistConnector(KVConnectorBase_V1): if self.is_prefill: self.connector_worker = PrefillConnectorWorker(vllm_config, str(self.host_ip), str(self.host_port), str(self.host_port + 1000)) else: - self.connector_worker = DecodeConnectorWorker(vllm_config, str(self.host_ip), self.cluster_id_start, str(self.host_port + 2000)) + self.connector_worker = DecodeConnectorWorker(vllm_config, str(self.host_ip), self.host_cluster_id, str(self.host_port + 2000)) self.connector_scheduler = None ############################################################ @@ -224,13 +237,13 @@ class LLMDataDistConnector(KVConnectorBase_V1): class PrefillConnectorScheduler: """Implementation of Scheduler side methods""" - def __init__(self, vllm_config, cluster_id_start: str, host_ip: str, host_port: str, trace_p_port: str): + def __init__(self, vllm_config, host_cluster_id: str, host_ip: str, host_port: str, trace_p_port: str): self.vllm_config = vllm_config - self.cluster_id_start = cluster_id_start + self.host_cluster_id = host_cluster_id self.host_ip = host_ip self.host_port = host_port self.trace_p_port=trace_p_port - logger.info("Initializing LLMDataDist Scheduler %s %s %s", cluster_id_start, host_ip, host_port) + logger.info("Initializing LLMDataDist Scheduler %s %s %s", host_cluster_id, host_ip, host_port) # initialize the dict to save requests finish time self.requests_finish_time = dict() # req_id -> headers @@ -305,7 +318,7 @@ class PrefillConnectorScheduler: return delay_free_blocks, dict( remote_block_ids=block_ids, - remote_cluster_id=self.cluster_id_start, + remote_cluster_id=self.host_cluster_id, remote_host_ip=f"tcp://{self.host_ip}:{self.host_port}", spec_token_ids=spec_token_ids, remote_dp_rank=self.vllm_config.parallel_config.data_parallel_rank, @@ -343,15 +356,18 @@ class PrefillConnectorWorker: if use_omni_attn_mgr: manager_cls = OmniBiGroupDataDistManager logger.warning(f"PrefillingConnector is using Omni datadist manager for KV transfer.") - self.datadist_manager = manager_cls(vllm_config) + if ENABLE_DYNAMIC_LLMDATADIST: + self.datadist_manager = manager_cls(vllm_config, self.host_ip, LLMDATADIST_BASE_PORT) + else: + self.datadist_manager = manager_cls(vllm_config) # initialize the dict to save requests finish time self.requests_finish_time = dict() def register_kv_caches(self, 
kv_caches: dict[str, torch.Tensor]): self.datadist_manager.register_memory(kv_caches) - self.datadist_manager.register_link() - pass + if not FLAG_ENABLE_DYNAMIC_LLMDATADIST: + self.datadist_manager.register_link() def start_load_kv(self, metadata: DatadistConnectorMetadataPrefill): pass @@ -577,9 +593,9 @@ class DecodeConnectorScheduler: class DecodeConnectorWorker: """Worker implementation for datadist.""" - def __init__(self, vllm_config: "VllmConfig", host_ip: str, cluster_id_start: int, trace_d_port: str): + def __init__(self, vllm_config: "VllmConfig", host_ip: str, host_cluster_id: int, trace_d_port: str): self.vllm_config = vllm_config - self.cluster_id_start = cluster_id_start + self.host_cluster_id = host_cluster_id self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local self.tp_rank = get_tensor_model_parallel_rank() additional_config = vllm_config.additional_config @@ -605,7 +621,10 @@ class DecodeConnectorWorker: if use_omni_attn_mgr: manager_cls = OmniBiGroupDataDistManager logger.warning(f"DecodeConnector is using Omni datadist manager for KV transfer.") - self.datadist_manager = manager_cls(vllm_config) + if FLAG_ENABLE_DYNAMIC_LLMDATADIST: + self.datadist_manager = manager_cls(vllm_config, host_ip, 0) + else: + self.datadist_manager = manager_cls(vllm_config) self._recving_transfers: list = [] self._done_recving_count: defaultdict[str, int] = defaultdict(lambda: 0) @@ -706,9 +725,11 @@ class DecodeConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.datadist_manager.register_memory(kv_caches) - self.datadist_manager.register_link() + if not FLAG_ENABLE_DYNAMIC_LLMDATADIST: + self.datadist_manager.register_link() # put multi-thread_pull_kv and multi_rank_pull_kv related registered_link_infos into queues - if self.multi_rank_pull_kv or self.multi_thread_pull_kv: + # TODO: multi_thread_pull_kv and multi_rank_pull_kv are currently only supported with the old datadist API + if (self.multi_rank_pull_kv or self.multi_thread_pull_kv) and not FLAG_ENABLE_DYNAMIC_LLMDATADIST: # In multi_rank_pull_kv mode, we create a thread for each P rank's cluster_id logger.info(f" ***** registered_link_infos: {self.datadist_manager.registered_link_infos}") for (cluster_id_start, prefill_dp_rank, d_rank), cluster_ids in self.datadist_manager.registered_link_infos.items(): @@ -802,7 +823,7 @@ class DecodeConnectorWorker: cluster_ids = [0] else: cluster_ids = self.datadist_manager.get_real_remote_cluster_ids(meta) - if self.multi_rank_pull_kv: + if self.multi_rank_pull_kv and not FLAG_ENABLE_DYNAMIC_LLMDATADIST: # If multi_rank_pull_kv is enabled, each DP rank will pull kv from multiple P ranks # and the cluster_ids are obtained from registered_link_infos # If the local_block_ids is a flat list of int, we can directly use it @@ -843,11 +864,12 @@ class DecodeConnectorWorker: 'local_block_ids': local_blocks, 'remote_block_ids': remote_blocks, 'remote_host_ip': meta.remote_host, + 'prefill_dp_rank': meta.remote_dp_rank, 'trace_headers': meta.trace_headers or {}, } logger.warning(f"*********** dst cluster_id is {cluster_id}.") self.queues[cluster_id].put(task) - elif self.multi_thread_pull_kv: + elif self.multi_thread_pull_kv and not FLAG_ENABLE_DYNAMIC_LLMDATADIST: task = { 'request_id': req_id, 'remote_request_id': meta.remote_request_id, 'dst_cluster_id': cluster_ids[0], 'local_block_ids': meta.local_block_ids, 'remote_block_ids': meta.remote_block_ids, 'remote_host_ip': meta.remote_host, + 'prefill_dp_rank': meta.remote_dp_rank, 'trace_headers': 
meta.trace_headers or {}, } @@ -869,11 +892,12 @@ class DecodeConnectorWorker: request_id=req_id, remote_request_id=meta.remote_request_id, remote_host_ip=meta.remote_host, + prefill_dp_rank=meta.remote_dp_rank, trace_headers=meta.trace_headers, ) futures.append(future) - if not self.multi_thread_pull_kv: + if not self.multi_thread_pull_kv or FLAG_ENABLE_DYNAMIC_LLMDATADIST: for future in futures: future.add_done_callback(handle_exception) @@ -885,13 +909,17 @@ class DecodeConnectorWorker: request_id: str, remote_request_id: str, remote_host_ip: str, + prefill_dp_rank: int, trace_headers: Optional[Mapping[str, str]] = None ): start = time.time() if hasattr(self.vllm_config.model_config.hf_config, 'param_sink_with_value'): local_block_ids.insert(0, 0) remote_block_ids.insert(0, 0) - self.datadist_manager.pull_kv(remote_block_ids, local_block_ids, dst_cluster_id) + if FLAG_ENABLE_DYNAMIC_LLMDATADIST: + self.datadist_manager.pull_kv(remote_block_ids, local_block_ids, dst_cluster_id, prefill_dp_rank) + else: + self.datadist_manager.pull_kv(remote_block_ids, local_block_ids, dst_cluster_id) if self.vllm_config.parallel_config.tensor_parallel_size == 1: # tp=1, send to prefill tp rank0 directly. diff --git a/omni/accelerators/pd/llmdatadist_connector_v2.py b/omni/accelerators/pd/llmdatadist_connector_v2.py deleted file mode 100644 index 633d93837..000000000 --- a/omni/accelerators/pd/llmdatadist_connector_v2.py +++ /dev/null @@ -1,967 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. - -import json -from collections.abc import Iterator -import math -import threading -from typing import TYPE_CHECKING, Any, Optional, Union, Mapping -import zmq -import os -import pickle -import time -import socket - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.logger import logger -from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm.v1.core.sched.output import SchedulerOutput - -from omni.accelerators.pd.utils import get_config_from_dict_or_env - -if TYPE_CHECKING: - from vllm.config import VllmConfig, KVTransferConfig - from vllm.attention.backends.abstract import AttentionMetadata - from vllm.forward_context import ForwardContext - from vllm.v1.request import Request -from vllm.v1.request import Request -from vllm.utils import round_down -from dataclasses import dataclass -from collections import defaultdict -import torch -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group, get_pp_group) - -from vllm.utils import get_open_port -from vllm.v1.request import RequestStatus -import queue -from concurrent.futures import ThreadPoolExecutor -from omni.tools.profiler.trace.tracing import get_local_ip - -GET_META_MSG = b"get_meta_msg" - -thread_dump_path = os.environ.get("VLLM_THREAD_DUMP_PATH", "/tmp/vllm_thread_info") -BLOCK_RELEASE_DELAY = int(os.environ.get("BLOCK_RELEASE_DELAY", 600)) # seconds, use to free blocks when the request is finished for a long time -LLMDATADIST_BASE_PORT = int(os.environ.get("VLLM_LLMDATADIST_BASE_PORT", 15567)) - -from omni.accelerators.pd.llmdatadist_manager_v1 import LLMDataDistManager, LLMDataDistConfig -from omni.tools.profiler.apply_profiler_patches import patch_request -patch_request() - - 
-@dataclass -class ReqMeta: - local_block_ids: list[int] - remote_block_ids: list[int] - remote_host: str - remote_cluster_id: str - spec_token_ids: Optional[list[int]] - remote_dp_rank: Optional[int] - remote_request_id: Optional[str] - trace_headers: Optional[Mapping[str, str]] = None - -@dataclass -class ReqMetaPrefill: - finish_time: float - -class DatadistConnectorMetadata(KVConnectorMetadata): - """Metadata for datadist connector.""" - - def __init__(self): - self.requests: dict[str, ReqMeta] = {} - - def add_new_req( - self, - request_id: str, - local_block_ids: list[int], - kv_transfer_params: dict[str, Any], - trace_headers: Optional[Mapping[str, str]] = None, - ): - self.requests[request_id] = ReqMeta( - local_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params["remote_block_ids"], - remote_host=kv_transfer_params["remote_host_ip"], - remote_cluster_id=kv_transfer_params["remote_cluster_id"], - spec_token_ids=kv_transfer_params["spec_token_ids"], - remote_dp_rank=kv_transfer_params.get("remote_dp_rank", 0), - remote_request_id=kv_transfer_params.get("remote_request_id", None), - trace_headers=trace_headers or {}, - ) - -class DatadistConnectorMetadataPrefill(KVConnectorMetadata): - """Metadata for datadist connector.""" - - def __init__(self): - self.requests: dict[str, ReqMeta] = {} - - def add_new_req( - self, - request_id: str, - finish_time: float, - ): - self.requests[request_id] = ReqMeta( - finish_time=finish_time - ) - - -class LLMDataDistConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - if vllm_config.kv_transfer_config is None: - raise RuntimeError("vllm_config.kv_transfer_config cannot be None") - - if vllm_config.model_config.is_deepseek_mla: - vllm_config.kv_transfer_config.kv_parallel_size = 1 - logger.info("Set kv_parallel_size to 1 when use deepseek mla model.") - - local_host_ip = get_local_ip() - local_host_port = LLMDATADIST_BASE_PORT - self.datadist_config = LLMDataDistConfig(vllm_config, local_host_ip, local_host_port, ignore_load_rank=True) - self.host_cluster_id = self.datadist_config.host_cluster_id - self.host_ip = local_host_ip - # Introduce the environment variable VLLM_LLMDATADIST_ZMQ_PORT to resolve ZMQ connection conflicts during - # multi-P deployments on the same machine. - # This variable should not be set separately unless specifically required for this scenario. 
- self.host_port = get_config_from_dict_or_env(vllm_config.kv_transfer_config, "kv_port", - "VLLM_LLMDATADIST_ZMQ_PORT", "5568", int) - dp_rank = vllm_config.parallel_config.data_parallel_rank - self.host_port += dp_rank - self.is_prefill = vllm_config.kv_transfer_config.kv_role == "kv_producer" - - if role == KVConnectorRole.SCHEDULER: - if self.is_prefill: - self.connector_scheduler = PrefillConnectorScheduler(vllm_config, self.host_cluster_id, self.host_ip, str(self.host_port), str(self.host_port + 1000)) - else: - self.connector_scheduler = DecodeConnectorScheduler(vllm_config, str(self.host_port + 2000)) - self.connector_worker = None - elif role == KVConnectorRole.WORKER: - if self.is_prefill: - self.connector_worker = PrefillConnectorWorker(vllm_config, str(self.host_ip), str(self.host_port), str(self.host_port + 1000)) - else: - self.connector_worker = DecodeConnectorWorker(vllm_config, str(self.host_ip), self.host_cluster_id, str(self.host_port + 2000)) - self.connector_scheduler = None - - ############################################################ - # Scheduler Side Methods - ############################################################ - - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: - if self.connector_scheduler is None: - raise RuntimeError("self.connector_scheduler cannot be None") - return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens) - - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - if self.connector_scheduler is None: - raise RuntimeError("self.connector_scheduler cannot be None") - return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens) - - def build_connector_meta( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: - if self.connector_scheduler is None: - raise RuntimeError("self.connector_scheduler cannot be None") - return self.connector_scheduler.build_connector_metadata(scheduler_output) - - def request_finished( - self, - request: "Request", - block_ids: list[int], - spec_token_ids: Optional[list[int]] = [] - ) -> tuple[bool, Optional[dict[str, Any]]]: - if self.connector_scheduler is None: - raise RuntimeError("self.connector_scheduler cannot be None") - return self.connector_scheduler.request_finished(request, block_ids, spec_token_ids) - - ############################################################ - # Worker Side Methods - ############################################################ - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - if self.connector_worker is None: - raise RuntimeError("self.connector_worker cannot be None") - return self.connector_worker.register_kv_caches(kv_caches) - - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: - """Get the finished recving and sending requests.""" - if self.connector_worker is None: - raise RuntimeError("self.connector_worker cannot be None") - return self.connector_worker.get_finished(self._connector_metadata) - - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - if self.connector_worker is None: - raise RuntimeError("self.connector_worker cannot be None") - if not isinstance(self._connector_metadata, Union[DatadistConnectorMetadata, DatadistConnectorMetadataPrefill]): - raise RuntimeError("self._connector_metadata must be an instance of DatadistConnectorMetadata or DatadistConnectorMetadataPrefill") - 
self.connector_worker.start_load_kv(self._connector_metadata) - - def wait_for_layer_load(self, layer_name: str) -> None: - """Connector does not do layerwise saving.""" - pass - - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - """Connector does not save explicitly.""" - pass - - def wait_for_save(self): - """Connector does not save explicitly.""" - pass - - def pop_trace_headers(self, req_id: str) -> dict: - if self.connector_scheduler is None: - raise RuntimeError("self.connector_scheduler cannot be None") - return self.connector_scheduler.pop_trace_headers(req_id) - -class PrefillConnectorScheduler: - """Implementation of Scheduler side methods""" - - def __init__(self, vllm_config, host_cluster_id: str, host_ip: str, host_port: str, trace_p_port: str): - self.vllm_config = vllm_config - self.host_cluster_id = host_cluster_id - self.host_ip = host_ip - self.host_port = host_port - self.trace_p_port=trace_p_port - logger.info("Initializing LLMDataDist Scheduler %s %s %s", host_cluster_id, host_ip, host_port) - # initialize the dict to save requests finish time - self.requests_finish_time = dict() - # req_id -> headers - self.sending_trace_headers: dict[str, dict] = {} - - self._transfer_lock = threading.Lock() - self.ctx = zmq.Context() - self.pull_socket = self.ctx.socket(zmq.PULL) - self.pull_socket.bind(f"tcp://{host_ip}:{trace_p_port}") - if os.getenv("PROFILING_NAMELIST", None): - self._listener_thread = threading.Thread(target=self._listen_worker_headers, daemon=True) - self._listener_thread.start() - - def _listen_worker_headers(self): - while True: - try: - msg_str = self.pull_socket.recv_string() - msg_list = json.loads(msg_str) - for msg in msg_list: - req_id = msg['remote_request_id'] or msg['request_id'] - headers = msg.get('trace_headers', {}) - with self._transfer_lock: - self.sending_trace_headers[req_id] = headers - except Exception as e: - logger.error(f"Failed to receive worker header (P): {e}") - time.sleep(1) - - def pop_trace_headers(self, req_id: str) -> dict: - with self._transfer_lock: - trace_headers = self.sending_trace_headers.pop(req_id, {}) - if trace_headers: - return trace_headers - - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: - return 0, False - - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - pass - - def build_connector_metadata( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: - metadata = DatadistConnectorMetadataPrefill() - # add requests finish time to metadata, to pass to worker connector - metadata.requests = {req_id: ReqMetaPrefill(finish_time=finish_time) - for req_id, finish_time in self.requests_finish_time.items()} - self.requests_finish_time.clear() - return metadata - - def request_finished( - self, - request: "Request", - block_ids: list[int], - spec_token_ids: Optional[list[int]] = [] - ) -> tuple[bool, Optional[dict[str, Any]]]: - """ - Once a request is finished, determine whether request blocks - should be freed now or will be sent asynchronously and freed later. 
- """ - if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: - return False, None - - delay_free_blocks = len(block_ids) > 0 - # record the finish time of the request - if delay_free_blocks: - self.requests_finish_time[request.request_id] = time.monotonic() - - return delay_free_blocks, dict( - remote_block_ids=block_ids, - remote_cluster_id=self.host_cluster_id, - remote_host_ip=f"tcp://{self.host_ip}:{self.host_port}", - spec_token_ids=spec_token_ids, - remote_dp_rank=self.vllm_config.parallel_config.data_parallel_rank, - remote_request_id=request.request_id - ) - - -class PrefillConnectorWorker: - """Implementation of Worker side methods""" - - def __init__(self, vllm_config: "VllmConfig", host_ip: str, host_port: str, trace_p_port: str): - # Metadata. - self.host_ip = host_ip - self.host_port = host_port - self.trace_p_port = trace_p_port - self.rank = get_tensor_model_parallel_rank() - if self.rank == 0 and get_pp_group().is_last_rank: - self.ctx = zmq.Context() - self.input_socket = self.ctx.socket(zmq.constants.PULL) - self.input_socket.bind(f"tcp://{self.host_ip}:{self.host_port}") - logger.info(f"ConnectWorker bind tcp://{self.host_ip}:{self.host_port}") - self._transfer_lock = threading.Lock() - self.receive_req_list = [] - thread_name = "prefill_connector_get_pulled_kv_req_list" - self.thread = threading.Thread(target=self.get_pulled_kv_req_list, daemon=True, name=thread_name) - self.thread.start() - dump_thread_to_file(self.thread, thread_name, thread_dump_path) - - # check whether omni attention is enabled - manager_cls = LLMDataDistManager - if vllm_config.additional_config and "enable_omni_attn" in vllm_config.additional_config: - # do import only when necessary - from omni.accelerators.cache import OmniBiGroupDataDistManager, check_omni_attn_cmd_arg - use_omni_attn_mgr = check_omni_attn_cmd_arg(vllm_config.additional_config) - if use_omni_attn_mgr: - manager_cls = OmniBiGroupDataDistManager - logger.warning(f"PrefillingConnector is using Omni datadist manager for KV transfer.") - local_host_port = LLMDATADIST_BASE_PORT - self.datadist_manager = manager_cls(vllm_config, self.host_ip, local_host_port) - - # initialize the dict to save requests finish time - self.requests_finish_time = dict() - - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - self.datadist_manager.register_memory(kv_caches) - - def start_load_kv(self, metadata: DatadistConnectorMetadataPrefill): - pass - - def get_finished(self, metadata: DatadistConnectorMetadataPrefill) -> tuple[set[str], set[str]]: - """ - Get requests that are done sending or recving. - """ - all_done_sending: set[str] = set() - all_done_recving: set[str] = set() - if self.rank == 0 and get_pp_group().is_last_rank: - # Update requests_finish_time with new finish times from metadata - with self._transfer_lock: - self.requests_finish_time.update( - {req_id: meta.finish_time for req_id, meta in metadata.requests.items()} - ) - current_time = time.monotonic() - # Identify requests whose finish time exceeds BLOCK_RELEASE_DELAY - out_date_reqs = [] - for req_id, finish_time in self.requests_finish_time.items(): - if current_time - finish_time > BLOCK_RELEASE_DELAY: - out_date_reqs.append(req_id) - else: - # Since the dict is ordered by finish_time, we can break early - break - for req_id in out_date_reqs: - logger.warning( - f"Request {req_id} is out of date, finish time: {self.requests_finish_time[req_id]}. Freeing blocks now." 
- ) - all_done_sending.add(req_id) - del self.requests_finish_time[req_id] - - if len(self.receive_req_list) == 0: - return all_done_sending, all_done_recving - - with self._transfer_lock: - for item in self.receive_req_list: - req_id = item.get('remote_request_id')#item['remote_request_id'] - headers = item.get('trace_headers', {}) - logger.debug(f"Get_finished: request {req_id}") - all_done_sending.add(req_id) - # if the request's kv has been received, remove it from requests_finish_time - if req_id in self.requests_finish_time: - del self.requests_finish_time[req_id] - self.receive_req_list.clear() - - return all_done_sending, all_done_recving - - def get_pulled_kv_req_list(self): - path_p = f"tcp://{self.host_ip}:{self.trace_p_port}" - socket_p = self.ctx.socket(zmq.PUSH) - socket_p.connect(path_p) - while True: - try: - if self.input_socket.poll(timeout=10) > 0: - message = self.input_socket.recv_string() - id_list = json.loads(message) # Parse the received JSON string into a list - logger.debug("Received: %s", id_list) - with self._transfer_lock: - self.receive_req_list.extend(id_list) - if os.getenv("PROFILING_NAMELIST", None): - json_data = json.dumps(id_list) - socket_p.send_string(json_data) - except Exception as e: - logger.error("get pulled kv req list failed: %s", e) - - -class DecodeConnectorScheduler: - """Implementation of Scheduler side methods""" - def __init__(self, vllm_config: VllmConfig, trace_d_port: str): - self.vllm_config = vllm_config - self.block_size = vllm_config.cache_config.block_size - self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} - self.processed_request: set[str] = set() - self.ctx = zmq.Context() - self.zmq_socket_map = {} - - self.host_ip = get_local_ip() - self.trace_d_port = trace_d_port - self.recving_trace_headers: dict[str, dict] = {} - self._transfer_lock = threading.Lock() - self.ctx = zmq.Context() - self.pull_socket = self.ctx.socket(zmq.PULL) - self.pull_socket.bind(f"tcp://{self.host_ip}:{self.trace_d_port}") - - if os.getenv("PROFILING_NAMELIST", None): - self._listener_thread = threading.Thread(target=self._listen_worker_headers, daemon=True) - self._listener_thread.start() - - additional_config = vllm_config.additional_config - if additional_config: - self.async_pull_kv = additional_config.get("async_pull_kv", False) - else: - self.async_pull_kv = False - - if self.async_pull_kv: - self.context = zmq.Context() - self.pub = self.context.socket(zmq.PUB) - kv_rank = self.vllm_config.kv_transfer_config.kv_rank - self.pub.bind(f"ipc:///tmp/sched-pub-{kv_rank}-{vllm_config.parallel_config.data_parallel_rank_local}") - - def _listen_worker_headers(self): - while True: - try: - msg_str = self.pull_socket.recv_string() - msg_list = json.loads(msg_str) - for msg in msg_list: - req_id = msg['remote_request_id'] - headers = msg.get('trace_headers', {}) - with self._transfer_lock: - self.recving_trace_headers[req_id] = headers - except Exception as e: - logger.error(f"Failed to receive worker header (D): {e}") - time.sleep(1) - - def pop_trace_headers(self, req_id: str) -> dict: - with self._transfer_lock: - trace_headers = self.recving_trace_headers.pop(req_id, {}) - if trace_headers: - return trace_headers - - def _send_pulled_kv_req_list(self, path, data): - if path in self.zmq_socket_map: - socket = self.zmq_socket_map[path] - else: - socket = self.ctx.socket(zmq.PUSH) - socket.connect(path) - self.zmq_socket_map[path] = socket - logger.info(f"create new socket path:{path}") - - try: - json_data = json.dumps(data) - 
socket.send_string(json_data) - logger.info(f"send string {json_data} path:{path}") - except Exception as e: - logger.error(f"Failed to send reqest_id {json_data} to prefill: {e}") - - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: - if request.request_id in self.processed_request: - return 0, False - params = request.kv_transfer_params - if params is None: - return 0, False - logger.debug( - "DatadistConnector get_num_new_matched_tokens: " - "num_computed_tokens=%s, kv_transfer_params=%s", - num_computed_tokens, params) - - if num_computed_tokens % self.block_size != 0: - raise RuntimeError("num_computed_tokens must be divisible by self.block_size") - rounded_num_prompt_tokens = self._round_up( - len(request.prompt_token_ids), self.block_size) - count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) - return count, count > 0 - - def _round_up(self, x: int, y: int) -> int: - return ((x + y - 1) // y) * y - - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - logger.debug(f"Request id {request.request_id}: blocks length is {len(blocks.blocks)}") - params = request.kv_transfer_params - logger.debug( - "DatadistConnector update_state_after_alloc: " - "num_external_tokens=%s, kv_transfer_params=%s", - num_external_tokens, params) - - self.processed_request.add(request.request_id) - if params is not None: - if params.get("remote_block_ids"): - if all(p in params for p in ("remote_cluster_id", "remote_host_ip")): - self._reqs_need_recv[request.request_id] = ( - request, blocks.get_unhashed_block_ids()) - else: - logger.warning( - "Got invalid KVTransferParams: %s.", params) - - def build_connector_metadata( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: - metadata = DatadistConnectorMetadata() - for req_id, (req, block_ids) in self._reqs_need_recv.items(): - if req.kv_transfer_params is None: - logger.warning(f"For reuqest {req_id}: kv_transfer_params now is None") - else: - metadata.add_new_req( - request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - trace_headers=req.trace_headers or {}, - ) - req.kv_transfer_params = None - self._reqs_need_recv.clear() - - if self.async_pull_kv: - if scheduler_output is None: - # Let go fast path - if metadata.requests: - serialized_data = pickle.dumps(metadata) - self.pub.send(serialized_data) - - return metadata - - def request_finished( - self, - request: "Request", - block_ids: list[int], - spec_token_ids: Optional[list[int]] = [] - ) -> tuple[bool, Optional[dict[str, Any]]]: - if request.request_id in self.processed_request: - self.processed_request.remove(request.request_id) - if request.status == RequestStatus.FINISHED_ABORTED and request.kv_transfer_params is not None: - self._send_pulled_kv_req_list(request.kv_transfer_params.get("remote_host_ip"), [{'request_id': request.request_id, 'trace_headers': request.trace_headers or {}}]) - return False, None - - -class DecodeConnectorWorker: - """Worker implementation for datadist.""" - - def __init__(self, vllm_config: "VllmConfig", host_ip: str, host_cluster_id: int, trace_d_port: str): - self.vllm_config = vllm_config - self.host_cluster_id = host_cluster_id - self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local - self.tp_rank = get_tensor_model_parallel_rank() - additional_config = vllm_config.additional_config - if additional_config: - self.async_pull_kv = 
additional_config.get("async_pull_kv", False) - self.multi_thread_pull_kv = additional_config.get("multi_thread_pull_kv", False) - self.multi_rank_pull_kv = additional_config.get("multi_rank_pull_kv", False) - else: - self.async_pull_kv = False - self.multi_thread_pull_kv = False - self.multi_rank_pull_kv = False - if self.multi_rank_pull_kv: - self.multi_thread_pull_kv = True - if vllm_config.parallel_config.tensor_parallel_size > 1 and self.multi_rank_pull_kv: - raise ValueError("multi_rank_pull_kv are not supported when tp > 1.") - - # check whether omni attention is enabled - manager_cls = LLMDataDistManager - if vllm_config.additional_config and "enable_omni_attn" in vllm_config.additional_config: - # do import only when necessary - from omni.accelerators.cache import OmniBiGroupDataDistManager, check_omni_attn_cmd_arg - use_omni_attn_mgr = check_omni_attn_cmd_arg(vllm_config.additional_config) - if use_omni_attn_mgr: - manager_cls = OmniBiGroupDataDistManager - logger.warning(f"DecodeConnector is using Omni datadist manager for KV transfer.") - self.datadist_manager = manager_cls(vllm_config, host_ip, 0) - - self._recving_transfers: list = [] - self._done_recving_count: defaultdict[str, int] = defaultdict(lambda: 0) - - self._pull_kv_lock = threading.Lock() - self.queues = {} # cluster_id -> queue.Queue - self.threads = {} # cluster_id -> threading.Thread - - self._transfer_lock = threading.Lock() - self.host_ip = host_ip - self.trace_d_port = trace_d_port - - self.ctx = zmq.Context() - self.zmq_socket_map = {} - - if self.async_pull_kv: - # dp_rank = vllm_config.parallel_config.data_parallel_rank_local - thread_name = f"async_pull_kv_{self.dp_rank}" - self.thread_on_fast_path_req = threading.Thread(target=self.on_fast_path_req, daemon=True, name=thread_name) - self.thread_on_fast_path_req.start() - logger.warning(f"DecodeConnectorWorker initialized with self.async_pull_kv enabled.") - - # Write thread name and native_id to file - dump_thread_to_file(self.thread_on_fast_path_req, thread_name, thread_dump_path) - - if self.multi_thread_pull_kv and self.vllm_config.parallel_config.tensor_parallel_size > 1: - self.tp_sync_path = f"ipc:///tmp/tp-sync-dp{self.vllm_config.parallel_config.data_parallel_rank}" - if get_tensor_model_parallel_rank() == 0: - self.input_socket = self.ctx.socket(zmq.constants.PULL) - self.input_socket.bind(self.tp_sync_path) - logger.info(f"ConnectWorker bind {self.tp_sync_path}") - - self.tp_sync_req_dict = {} - thread_name = f"decode_connector_sync_pulled_tp_kvcache_and_send_dp{self.vllm_config.parallel_config.data_parallel_rank}" - self.sync_thread = threading.Thread(target=self.sync_pulled_tp_kvcache_and_send, daemon=True, - name=thread_name) - self.sync_thread.start() - dump_thread_to_file(self.sync_thread, thread_name, thread_dump_path) - - def sync_pulled_tp_kvcache_and_send(self): - while True: - try: - if self.input_socket.poll(timeout=10) > 0: - data = self.input_socket.recv_json() - request_id = data[0].get("request_id") - remote_host_ip = data[0].get("remote_host_ip") - remote_request_id = data[0].get("remote_request_id", None) - trace_headers = data[0].get("trace_headers", {}) - # if request_id not in dict, set to 0, else do nothing - self.tp_sync_req_dict.setdefault(request_id, 0) - self.tp_sync_req_dict[request_id] += 1 - logger.debug(f"{request_id} finish pull kv {self.tp_sync_req_dict[request_id]} times.") - if self.tp_sync_req_dict[request_id] == self.vllm_config.parallel_config.tensor_parallel_size: - self.tp_sync_req_dict.pop(request_id) - 
self._send_pulled_kv_req_list(remote_host_ip, [{'remote_request_id': remote_request_id, 'trace_headers': trace_headers or {}}]) - with self._transfer_lock: - self._recving_transfers.append(request_id) - except Exception as e: - logger.error("Sync pulled kv when tp > 1 and send failed: %s", e) - - def on_fast_path_req(self): - context = zmq.Context() - sub = context.socket(zmq.SUB) - kv_rank = self.vllm_config.kv_transfer_config.kv_rank - sub.connect(f"ipc:///tmp/sched-pub-{kv_rank}-{self.vllm_config.parallel_config.data_parallel_rank_local}") - sub.setsockopt_string(zmq.SUBSCRIBE, "") - - while True: - serialized_data = sub.recv() - metadata = pickle.loads(serialized_data) - for req_id, meta in metadata.requests.items(): - if (len(meta.local_block_ids) > 0) and (len(meta.remote_block_ids) > 0): - self.start_load_kv(metadata) - if self.tp_rank == 0: - logger.info( - "Received fast path request for request %s with " - "local_block_ids: %s, remote_block_ids: %s.", - req_id, - len(meta.local_block_ids), - len(meta.remote_block_ids) - ) - - def worker(self, cluster_id): - q = self.queues[cluster_id] - time.sleep(0) - while True: - task = q.get() - if task is None: - continue - try: - self._read_blocks(**task) - except Exception as e: - logger.error("KV transfer task failed in thread %s: %s", cluster_id, e) - patch_data = [{'request_id': task['request_id'], 'trace_headers': task.get('trace_headers', {})}] - self._send_pulled_kv_req_list(task['remote_host_ip'], patch_data) - raise RuntimeError(f"Failed to pull kv for request:{task['request_id']} from cluster:{cluster_id}.") - q.task_done() - - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - self.datadist_manager.register_memory(kv_caches) - # TODO:put multi-thread_pull_kv and multi_rank_pull_kv related registered_link_infos into queues - # In single thread pull kv mode, we use a single thread to pull kv - logger.info(" ***** Using single thread to pull kv.") - max_concurrents = 1 - self.executor = ThreadPoolExecutor(max_workers=max_concurrents) - - logger.debug("Finish register_kv_caches.") - - # Now go asynchronous pull_kv - def start_load_kv(self, metadata: DatadistConnectorMetadata): - logger.debug(f" ***** start_load_kv: {len(metadata.requests)}") - futures = [] - for req_id, meta in metadata.requests.items(): - # if the local_block_ids is empty, skip pulling kv for the request - if len(meta.local_block_ids) == 0: - if self.tp_rank == 0: - logger.info(f" ***** Request {req_id} has 0 local blocks, skip load kv.") - continue - # If local_block_ids is a flat list of int, omni-attention is not used - # and we can directly use the local_block_ids and remote_block_ids - if isinstance(meta.local_block_ids[0], int): - # local_block_ids (kv blocks in D) is more than remote_block_ids (kv blocks in P) - # leaded by lookahead num, which is used by eagle and multi step - if len(meta.remote_block_ids) < len(meta.local_block_ids): - meta.local_block_ids = meta.local_block_ids[:len(meta.remote_block_ids)] - logger.debug("look ahead token num is greater than 0") - # If remote_block_ids is more than local_block_ids, we only need the last N remote blocks - # where N is the number of local blocks - elif len(meta.remote_block_ids) > len(meta.local_block_ids): - meta.remote_block_ids = meta.remote_block_ids[-len(meta.local_block_ids):] - if self.tp_rank == 0: - logger.info( - " ***** start_load_kv for request %s " - "Num local_block_ids: %s. 
Num remote_block_ids: %s.", - req_id, - len(meta.local_block_ids), - len(meta.remote_block_ids) - ) - # If local_block_ids is a list of lists (e.g., [[], []]), omni-attention is used - # local_block_ids[0] is a list of local block ids for uncompressed layers - # local_block_ids[1] is a list of local block ids for compressed layers - elif isinstance(meta.local_block_ids[0], list): - # If local_block_ids[0] is a list of lists, we need to ensure that remote_block_ids - # is a list of lists as well, where each sublist corresponds to the local_block - meta.remote_block_ids = [meta.remote_block_ids] * len(meta.local_block_ids) - # If local_block_ids[0] is empty, skip pulling kv for the request - if len(meta.local_block_ids[0]) == 0: - if self.tp_rank == 0: - logger.info(f" ***** Request {req_id} has 0 local blocks, skip load kv.") - continue - # remote_block_ids in P is less than local_block_ids[0] in D, - # leaded by lookahead num, which is used by eagle and multi step - elif len(meta.remote_block_ids[0]) < len(meta.local_block_ids[0]): - meta.local_block_ids[0] = meta.local_block_ids[0][:len(meta.remote_block_ids[0])] - logger.debug("look ahead token num is greater than 0") - # If remote_block_ids in P is more than local_block_ids[0] in D, we only need the last N remote blocks - elif len(meta.remote_block_ids[0]) > len(meta.local_block_ids[0]): - meta.remote_block_ids[0] = meta.remote_block_ids[0][-len(meta.local_block_ids[0]):] - if self.tp_rank == 0: - logger.info( - " ***** start_load_kv for request %s " - "Num local_block_ids: %s. Num remote_block_ids: %s.", - req_id, - len(meta.local_block_ids[0]), - len(meta.remote_block_ids[0]) - ) - # handle the unexpected case where local_block_ids is not a list of int or list of lists - else: - logger.error(f"Unexpected type for meta.local_block_ids[0]: {type(meta.local_block_ids[0])}") - raise RuntimeError(f"Unexpected type for meta.local_block_ids[0]: {type(meta.local_block_ids[0])}") - cluster_ids = self.datadist_manager.get_real_remote_cluster_ids(meta) - if self.multi_rank_pull_kv: - # If multi_rank_pull_kv is enabled, each DP rank will pull kv from multiple P ranks - # and the cluster_ids are obtained from registered_link_infos - # If the local_block_ids is a flat list of int, we can directly use it - # As multi_rank_pull_kv is designed to pull kv from two P ranks, - # we split the local_block_ids and remote_block_ids into two parts - if not isinstance(meta.local_block_ids[0], list): - block_thre = len(meta.local_block_ids) // 2 - # If the local_block_ids is a flat list of list, only split the blocks for uncompressed layers - else: - block_thre = len(meta.local_block_ids[0]) // 2 - for idx_cluster, cluster_id in enumerate(cluster_ids): - if not isinstance(meta.local_block_ids[0], list): - if idx_cluster == 0: - local_blocks = meta.local_block_ids[:block_thre] - remote_blocks = meta.remote_block_ids[:block_thre] - len_local_blocks = len(local_blocks) - else: - local_blocks = meta.local_block_ids[block_thre:] - remote_blocks = meta.remote_block_ids[block_thre:] - len_local_blocks = len(local_blocks) - else: - if idx_cluster == 0: - # For uncompressed layers, split the local_block_ids[0] and remote_block_ids - # For compressed layers, only pull kv from the second P rank - local_blocks = [meta.local_block_ids[0][:block_thre], []] - # remote_blocks need to be split as well for getting kv blocks for compressed layers in P - remote_blocks = [meta.remote_block_ids[0][:block_thre], []] - len_local_blocks = len(local_blocks[0]) - else: - local_blocks = 
[meta.local_block_ids[0][block_thre:], meta.local_block_ids[1]] - remote_blocks = [meta.remote_block_ids[0][block_thre:], meta.remote_block_ids[1]] - len_local_blocks = len(local_blocks[0]) - if len_local_blocks > 0: - task = { - 'request_id': req_id, - 'remote_request_id': meta.remote_request_id, - 'dst_cluster_id': cluster_id, - 'local_block_ids': local_blocks, - 'remote_block_ids': remote_blocks, - 'remote_host_ip': meta.remote_host, - 'prefill_dp_rank': meta.remote_dp_rank, - 'trace_headers': meta.trace_headers or {}, - } - logger.warning(f"*********** dst cluster_id is {cluster_id}.") - self.queues[cluster_id].put(task) - elif self.multi_thread_pull_kv: - task = { - 'request_id': req_id, - 'remote_request_id': meta.remote_request_id, - 'dst_cluster_id': cluster_ids[0], - 'local_block_ids': meta.local_block_ids, - 'remote_block_ids': meta.remote_block_ids, - 'remote_host_ip': meta.remote_host, - 'prefill_dp_rank': meta.remote_dp_rank, - 'trace_headers': meta.trace_headers or {}, - } - - self.queues[cluster_ids[0]].put(task) - else: - # Use ThreadPoolExecutor to handle the task - future = self.executor.submit( - self._read_blocks, - local_block_ids=meta.local_block_ids, - remote_block_ids=meta.remote_block_ids, - dst_cluster_id=cluster_ids[0], - request_id=req_id, - remote_request_id=meta.remote_request_id, - remote_host_ip=meta.remote_host, - prefill_dp_rank=meta.remote_dp_rank, - trace_headers=meta.trace_headers, - ) - futures.append(future) - - if not self.multi_thread_pull_kv: - for future in futures: - future.add_done_callback(handle_exception) - - def _read_blocks( - self, - local_block_ids: list[int], - remote_block_ids: list[int], - dst_cluster_id: str, - request_id: str, - remote_request_id: str, - remote_host_ip: str, - prefill_dp_rank: int, - trace_headers: Optional[Mapping[str, str]] = None - ): - start = time.time() - self.datadist_manager.pull_kv(remote_block_ids, local_block_ids, dst_cluster_id, prefill_dp_rank) - - if self.vllm_config.parallel_config.tensor_parallel_size == 1: - # tp=1, send to prefill tp rank0 directly. - self._send_pulled_kv_req_list(remote_host_ip, [{'remote_request_id': remote_request_id, 'trace_headers': trace_headers or {}}]) - with self._transfer_lock: - self._recving_transfers.append(request_id) - else: - if self.multi_thread_pull_kv: - # tp>1, send to decode to rank0 firstly. 
- self._send_pulled_kv_req_list( - self.tp_sync_path, - { - "request_id": request_id, - "remote_request_id": remote_request_id, - "remote_host_ip": remote_host_ip, - 'trace_headers': trace_headers or {} - } - ) - else: - torch.distributed.barrier(group=get_tp_group().cpu_group) - if get_tensor_model_parallel_rank() == 0: - self._send_pulled_kv_req_list(remote_host_ip, [{'remote_request_id': remote_request_id, 'trace_headers': trace_headers or {}}]) - with self._transfer_lock: - self._recving_transfers.append(request_id) - logger.debug(f" ***** read block, req_id:{request_id}, local_block_ids:{local_block_ids}, remote_block_ids:{remote_block_ids}") - cost = time.time() - start - if self.tp_rank == 0: - logger.info(f" ***** read block, req_id:{request_id}, cost:{cost:.6f}") - - - def _send_pulled_kv_req_list(self, path, data): - if path in self.zmq_socket_map: - socket = self.zmq_socket_map[path] - else: - socket = self.ctx.socket(zmq.PUSH) - socket.connect(path) - self.zmq_socket_map[path] = socket - logger.info(f"create new socket path:{path}") - - path_d = f"tcp://{self.host_ip}:{self.trace_d_port}" - socket_d = self.ctx.socket(zmq.PUSH) - socket_d.connect(path_d) - - try: - json_data = json.dumps(data) - socket.send_string(json_data) - logger.info(f"send string {json_data} path:{path}") - if os.getenv("PROFILING_NAMELIST", None): - socket_d.send_string(json_data) - except Exception as e: - logger.error(f"Failed to send reqest_id {json_data} to prefill: {e}") - - def get_finished(self, metadata: DatadistConnectorMetadata) -> tuple[set[str], set[str]]: - # for decode size, done_sending is no need - all_done_sending: set[str] = set() - with self._transfer_lock: - all_done_recving = self._pop_done_transfers(self._recving_transfers) - if len(all_done_recving) > 0: - logger.debug( - "Get_finished: %s requests done recving", len(all_done_recving)) - - return all_done_sending, all_done_recving - - def _pop_done_transfers(self, transfers: list) -> set[str]: - done_req_ids: set[str] = set() - for req_id in transfers: - done_req_ids.add(req_id) - self._recving_transfers.clear() - return done_req_ids - -def handle_exception(future): - if future.exception(): - logger.error(f"Exception occurred in future: {future.exception()}") - raise future.exception() - -def dump_thread_to_file(thread, thread_name: str, folder_path: str): - - timeout = 5 # seconds - start_time = time.time() - while not hasattr(thread, "native_id"): - if time.time() - start_time > timeout: - logger.error(f"Timeout waiting for thread {thread_name} to have native_id.") - return - time.sleep(0.005) - - # Ensure the folder exists - if not os.path.exists(folder_path): - try: - os.makedirs(folder_path, exist_ok=True) - except Exception as e: - logger.error(f"Failed to create folder {folder_path}: {e}") - return - - file_path = os.path.join(folder_path, thread_name) - try: - with open(file_path, "w", encoding="utf-8") as f: - f.write(str(thread.native_id)) - except Exception as e: - logger.error(f"Failed to write thread info to {file_path}: {e}") -- Gitee From 797d5e338ae44fd18a995f1c4bb115cfff868d0e Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 2 Dec 2025 18:47:22 +0800 Subject: [PATCH 07/15] resolve confilcts in pd.py for omni-attn --- omni/accelerators/cache/pd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omni/accelerators/cache/pd.py b/omni/accelerators/cache/pd.py index d9e432c15..b75ca712f 100644 --- a/omni/accelerators/cache/pd.py +++ b/omni/accelerators/cache/pd.py @@ -119,7 +119,7 @@ class 
OmniBiGroupDataDistManager(LLMDataDistManager): used_ids.add(cur_id) prompt_cache_key = BlocksCacheKey( - prompt_cluster_id=prompt_cluster_id, model_id=cur_id) + prompt_cluster_id=prompt_cluster_id + cluster_id_pp_offset, model_id=cur_id) if flag == 0: if USE_NEW_DATADIST: ret = self._pull_blocks(prompt_cache_key, kv_cache, -- Gitee From 03535acce3342eddff8f2967f16b7d30d9b52ba8 Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 2 Dec 2025 18:50:24 +0800 Subject: [PATCH 08/15] fix typos --- omni/accelerators/cache/pd.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/omni/accelerators/cache/pd.py b/omni/accelerators/cache/pd.py index b75ca712f..123fd3214 100644 --- a/omni/accelerators/cache/pd.py +++ b/omni/accelerators/cache/pd.py @@ -26,12 +26,12 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): super().__init__(vllm_config, local_host_ip, host_port) else: super().__init__(vllm_config) - self.registerd_kv_caches: list[list[Cache]] = [[], []] + self.registered_kv_caches: list[list[Cache]] = [[], []] @override def register_memory(self, kv_caches: dict[str, torch.Tensor]): - if any(len(group_cache) > 0 for group_cache in self.registerd_kv_caches): - raise ValueError("Attr `registerd_kv_caches` must be empty before register kv_caches.") + if any(len(group_cache) > 0 for group_cache in self.registered_kv_caches): + raise ValueError("Attr `registered_kv_caches` must be empty before register kv_caches.") # NOTE: flatten_kv_caches is a nested list like [[k1,k2,...,kL], [v1,v2,...,vL]] # if KV is just one tensor, then it's [[kv1,kv2,...,kvL]] flatten_kv_caches: list[list[torch.Tensor]] = unzip_kv_cache_dict(kv_caches) @@ -69,7 +69,7 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): for model_id, sub_kv_caches in enumerate(flatten_kv_caches): # sub_kv_caches is a list of Tensors, whose length is number of layers - for flag in range(len(self.registerd_kv_caches)): + for flag in range(len(self.registered_kv_caches)): group_kv_caches = [sub_kv_caches[j] for j in layer_idx[flag]] cache_desc = CacheDesc(num_tensors=len(group_kv_caches), shape=tuple(group_kv_caches[0].shape), data_type=TORCH_DTYPE_TO_NPU_DTYPE[group_kv_caches[0].dtype]) @@ -88,8 +88,8 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): cache_key = None cache = self.data_dist_engine.cache_manager.register_blocks_cache(cache_desc, cache_addrs, cache_key) - self.registerd_kv_caches[flag].append(cache) - logger.error(f" ***** registerd_kv_caches num:{sum([len(group_kv_caches) for group_kv_caches in self.registerd_kv_caches])}") + self.registered_kv_caches[flag].append(cache) + logger.error(f" ***** registered_kv_caches num:{sum([len(group_kv_caches) for group_kv_caches in self.registered_kv_caches])}") @override def pull_kv(self, src_blocks: list[int], tgt_blocks: list[list[int]], prompt_cluster_id: int, prefill_dp_rank: int=0): @@ -106,13 +106,13 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): torch.npu.set_device(f"npu:{self.local_rank}") sink, recent = itfc.SINK, itfc.RECENT omni_max_blocks = sink + recent - N = len(self.registerd_kv_caches[0]) + N = len(self.registered_kv_caches[0]) used_ids = set() - for flag in range(len(self.registerd_kv_caches)): + for flag in range(len(self.registered_kv_caches)): group_src_blocks: list[int] = src_blocks[flag] group_tgt_blocks: list[int] = tgt_blocks[flag] - for model_id, kv_cache in enumerate(self.registerd_kv_caches[flag]): + for model_id, kv_cache in enumerate(self.registered_kv_caches[flag]): cur_id = flag * N + model_id if 
cur_id in used_ids: raise RuntimeError(f"Error! ID already pulled. {N=}, {model_id=}, {used_ids=}, {cur_id=}.") -- Gitee From 82e66a0d071ce61eb878479af671efe59186ae88 Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 2 Dec 2025 18:59:49 +0800 Subject: [PATCH 09/15] resolve the conflicts due to the introduction of pp in omni-attn --- omni/accelerators/cache/pd.py | 110 +++++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 35 deletions(-) diff --git a/omni/accelerators/cache/pd.py b/omni/accelerators/cache/pd.py index 123fd3214..494f57ed4 100644 --- a/omni/accelerators/cache/pd.py +++ b/omni/accelerators/cache/pd.py @@ -37,9 +37,20 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): flatten_kv_caches: list[list[torch.Tensor]] = unzip_kv_cache_dict(kv_caches) num_layers = len(flatten_kv_caches[0]) + PATTERN = itfc.PATTERN # partition layer indices into full and omni - full_layer_idx = [i for i in range(num_layers) if itfc.PATTERN[i] == 0] - omni_layer_idx = [i for i in range(num_layers) if itfc.PATTERN[i] == 1] + if self.data_dist_config.is_prefill: + # 1. 获取rank + # 2. 拿到对应的stage的layer_start和end + # 3. PATTERN + pp_rank = self.rank // self.prefill_tp_dp_size + prefill_pp_partitions = self.data_dist_config.kv_producer_pp_partitions + pp_start_layer_idx = sum(prefill_pp_partitions[:pp_rank]) + pp_end_layer_idx = pp_start_layer_idx + prefill_pp_partitions[pp_rank] + PATTERN = itfc.PATTERN[pp_start_layer_idx : pp_end_layer_idx] + + full_layer_idx = [i for i in range(num_layers) if PATTERN[i] == 0] + omni_layer_idx = [i for i in range(num_layers) if PATTERN[i] == 1] layer_idx = [full_layer_idx, omni_layer_idx] # check validity @@ -62,6 +73,13 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): logger.warning("Trying to register grouped KV caches for OMNI attention, with " f"{len(full_layer_idx)} full attn layers and {len(omni_layer_idx)} omni attn layers.") + if self.data_dist_config.is_prefill: + self._register_caches_prefill(flatten_kv_caches, layer_idx) + else: + self._register_caches_decode(flatten_kv_caches, layer_idx) + logger.error(f" ***** registered_kv_caches num:{sum([len(group_kv_caches) for group_kv_caches in self.registered_kv_caches])}") + + def _register_caches_prefill(self, flatten_kv_caches, layer_idx): # model_id related N = len(flatten_kv_caches) used_ids = set() @@ -75,21 +93,43 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): data_type=TORCH_DTYPE_TO_NPU_DTYPE[group_kv_caches[0].dtype]) cache_addrs = [int(item.data_ptr()) for item in group_kv_caches] - if self.data_dist_config.is_prefill: - # NOTE: when assigning model_id to cache_key, we consider KV group information - # e.g., if registered_kv_caches = [[K_full, V_full], [K_omni, V_omni]] - # then model_ids should be [[0, 1], [2, 3]] - cur_id = flag * N + model_id - if cur_id in used_ids: - raise RuntimeError(f"Error! ID already used. {N=}, {model_id=}, {used_ids=}, {cur_id=}.") - used_ids.add(cur_id) - cache_key = BlocksCacheKey(self.data_dist_engine.cluster_id, model_id=cur_id) - else: - cache_key = None + # NOTE: when assigning model_id to cache_key, we consider KV group information + # e.g., if registered_kv_caches = [[K_full, V_full], [K_omni, V_omni]] + # then model_ids should be [[0, 1], [2, 3]] + cur_id = flag * N + model_id + if cur_id in used_ids: + raise RuntimeError(f"Error! ID already used. 
{N=}, {model_id=}, {used_ids=}, {cur_id=}.") + used_ids.add(cur_id) + cache_key = BlocksCacheKey(self.data_dist_engine.cluster_id, model_id=cur_id) cache = self.data_dist_engine.cache_manager.register_blocks_cache(cache_desc, cache_addrs, cache_key) self.registered_kv_caches[flag].append(cache) - logger.error(f" ***** registered_kv_caches num:{sum([len(group_kv_caches) for group_kv_caches in self.registered_kv_caches])}") + + def _register_caches_decode(self, flatten_kv_caches, layer_idx): + prefill_pp_partitions = self.data_dist_config.kv_producer_pp_partitions + for flag in range(len(self.registered_kv_caches)): + cnt_layer_num = 0 + layer_idx_start = 0 + for cur_pp_stage_layer_num in prefill_pp_partitions: + cur_pp_stage_kv_caches = [] + layer_idx_end = layer_idx_start + while layer_idx_end < len(layer_idx[flag]) and cnt_layer_num <= layer_idx[flag][layer_idx_end] < cnt_layer_num + cur_pp_stage_layer_num: + layer_idx_end += 1 + flag_stage_layer_idx = layer_idx[flag][layer_idx_start : layer_idx_end] + layer_idx_start = layer_idx_end + for sub_kv_caches in flatten_kv_caches: + # sub_kv_caches is a list of Tensors, whose length is number of layers + # flag_stage_layer_idx = layer_idx[flag][cnt_layer_num : cnt_layer_num + cur_pp_stage_layer_num] + group_kv_caches = [sub_kv_caches[j] for j in flag_stage_layer_idx] + # group_kv_caches = [sub_kv_caches[j] for j in layer_idx[flag] if cnt_layer_num <= j < cnt_layer_num + cur_pp_stage_layer_num] + cache_desc = CacheDesc(num_tensors=len(group_kv_caches), shape=tuple(group_kv_caches[0].shape), + data_type=TORCH_DTYPE_TO_NPU_DTYPE[group_kv_caches[0].dtype]) + cache_addrs = [int(item.data_ptr()) for item in group_kv_caches] + + cache = self.data_dist_engine.cache_manager.register_blocks_cache(cache_desc, cache_addrs, None) + cur_pp_stage_kv_caches.append(cache) + self.registered_kv_caches[flag].append(cur_pp_stage_kv_caches) + cnt_layer_num += cur_pp_stage_layer_num @override def pull_kv(self, src_blocks: list[int], tgt_blocks: list[list[int]], prompt_cluster_id: int, prefill_dp_rank: int=0): @@ -107,38 +147,38 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): sink, recent = itfc.SINK, itfc.RECENT omni_max_blocks = sink + recent N = len(self.registered_kv_caches[0]) - used_ids = set() + # used_ids = set() for flag in range(len(self.registered_kv_caches)): group_src_blocks: list[int] = src_blocks[flag] group_tgt_blocks: list[int] = tgt_blocks[flag] - for model_id, kv_cache in enumerate(self.registered_kv_caches[flag]): - cur_id = flag * N + model_id - if cur_id in used_ids: - raise RuntimeError(f"Error! ID already pulled. 
{N=}, {model_id=}, {used_ids=}, {cur_id=}.") - used_ids.add(cur_id) + for pp_stage_ind, cur_pp_stage_kv_caches in enumerate(self.registered_kv_caches[flag]): + for model_id, kv_cache in enumerate(cur_pp_stage_kv_caches): + cur_id = flag * N + model_id + cluster_id_pp_offset = pp_stage_ind * self.prefill_tp_dp_size - prompt_cache_key = BlocksCacheKey( - prompt_cluster_id=prompt_cluster_id + cluster_id_pp_offset, model_id=cur_id) + prompt_cache_key = BlocksCacheKey( + prompt_cluster_id=prompt_cluster_id + cluster_id_pp_offset, model_id=cur_id) if flag == 0: if USE_NEW_DATADIST: ret = self._pull_blocks(prompt_cache_key, kv_cache, group_src_blocks, group_tgt_blocks) else: self._pull_blocks(prompt_cache_key, kv_cache, - group_src_blocks, group_tgt_blocks) - else: - if len(group_tgt_blocks) == 0: - continue - tmp_src, tmp_tgt = group_src_blocks, group_tgt_blocks - if len(group_src_blocks) < omni_max_blocks: - tmp_tgt = group_tgt_blocks[:len(group_src_blocks)] - elif len(group_src_blocks) > omni_max_blocks: - tmp_src = group_src_blocks[:sink] + group_src_blocks[-recent:] - if len(tmp_src) != len(tmp_tgt): - raise RuntimeError("src and tgt cannot match for omni kv caches. " - f"{src_blocks=}, {tgt_blocks=}, " - f"{len(tmp_src)=}, {len(tmp_tgt)=}.") + group_src_blocks, group_tgt_blocks) + else: + if len(group_tgt_blocks) == 0: + continue + tmp_src, tmp_tgt = group_src_blocks, group_tgt_blocks + if len(group_src_blocks) < omni_max_blocks: + tmp_tgt = group_tgt_blocks[:len(group_src_blocks)] + elif len(group_src_blocks) > omni_max_blocks: + tmp_src = group_src_blocks[:sink] + group_src_blocks[-recent:] + if len(tmp_src) != len(tmp_tgt): + raise RuntimeError("src and tgt cannot match for omni kv caches. " + f"{src_blocks=}, {tgt_blocks=}, " + f"{tmp_src=}, {tmp_tgt=}, " + f"{len(tmp_src)=}, {len(tmp_tgt)=}.") if USE_NEW_DATADIST: ret = self._pull_blocks(prompt_cache_key, kv_cache, tmp_src, tmp_tgt) -- Gitee From c8698d8ea79cb7fcfa0db8bf13161d5b780ef745 Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 2 Dec 2025 19:01:15 +0800 Subject: [PATCH 10/15] resolve the conflicts due to the introduction of pp in omni-attn --- omni/accelerators/cache/pd.py | 62 +++++++++++++++++------------------ 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/omni/accelerators/cache/pd.py b/omni/accelerators/cache/pd.py index 494f57ed4..6f17a1260 100644 --- a/omni/accelerators/cache/pd.py +++ b/omni/accelerators/cache/pd.py @@ -159,36 +159,36 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): prompt_cache_key = BlocksCacheKey( prompt_cluster_id=prompt_cluster_id + cluster_id_pp_offset, model_id=cur_id) - if flag == 0: - if USE_NEW_DATADIST: - ret = self._pull_blocks(prompt_cache_key, kv_cache, - group_src_blocks, group_tgt_blocks) - else: - self._pull_blocks(prompt_cache_key, kv_cache, + if flag == 0: + if USE_NEW_DATADIST: + ret = self._pull_blocks(prompt_cache_key, kv_cache, group_src_blocks, group_tgt_blocks) - else: - if len(group_tgt_blocks) == 0: - continue - tmp_src, tmp_tgt = group_src_blocks, group_tgt_blocks - if len(group_src_blocks) < omni_max_blocks: - tmp_tgt = group_tgt_blocks[:len(group_src_blocks)] - elif len(group_src_blocks) > omni_max_blocks: - tmp_src = group_src_blocks[:sink] + group_src_blocks[-recent:] - if len(tmp_src) != len(tmp_tgt): - raise RuntimeError("src and tgt cannot match for omni kv caches. 
" - f"{src_blocks=}, {tgt_blocks=}, " - f"{tmp_src=}, {tmp_tgt=}, " - f"{len(tmp_src)=}, {len(tmp_tgt)=}.") + else: + self._pull_blocks(prompt_cache_key, kv_cache, + group_src_blocks, group_tgt_blocks) + else: + if len(group_tgt_blocks) == 0: + continue + tmp_src, tmp_tgt = group_src_blocks, group_tgt_blocks + if len(group_src_blocks) < omni_max_blocks: + tmp_tgt = group_tgt_blocks[:len(group_src_blocks)] + elif len(group_src_blocks) > omni_max_blocks: + tmp_src = group_src_blocks[:sink] + group_src_blocks[-recent:] + if len(tmp_src) != len(tmp_tgt): + raise RuntimeError("src and tgt cannot match for omni kv caches. " + f"{src_blocks=}, {tgt_blocks=}, " + f"{tmp_src=}, {tmp_tgt=}, " + f"{len(tmp_src)=}, {len(tmp_tgt)=}.") + if USE_NEW_DATADIST: + ret = self._pull_blocks(prompt_cache_key, kv_cache, + tmp_src, tmp_tgt) + else: + self._pull_blocks(prompt_cache_key, kv_cache, + tmp_src, tmp_tgt) if USE_NEW_DATADIST: - ret = self._pull_blocks(prompt_cache_key, kv_cache, - tmp_src, tmp_tgt) - else: - self._pull_blocks(prompt_cache_key, kv_cache, - tmp_src, tmp_tgt) - if USE_NEW_DATADIST: - if not ret: - self._refresh_link(prompt_cluster_id, prefill_dp_rank, self.rank) - ret_updated = self._pull_blocks(prompt_cache_key, kv_cache, - tmp_src, tmp_tgt) - if not ret_updated: - raise RuntimeError(f"Failed to pull kv even if rebuild the kv link!") + if not ret: + self._refresh_link(prompt_cluster_id, prefill_dp_rank, self.rank) + ret_updated = self._pull_blocks(prompt_cache_key, kv_cache, + tmp_src, tmp_tgt) + if not ret_updated: + raise RuntimeError(f"Failed to pull kv even if rebuild the kv link!") -- Gitee From ca2bb29bc9554dbbac002d255f83a8a08a2c68ae Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 2 Dec 2025 19:02:47 +0800 Subject: [PATCH 11/15] resolve the conflicts due to the introduction of pp in omni-attn --- omni/accelerators/cache/pd.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/omni/accelerators/cache/pd.py b/omni/accelerators/cache/pd.py index 6f17a1260..87f3218db 100644 --- a/omni/accelerators/cache/pd.py +++ b/omni/accelerators/cache/pd.py @@ -166,19 +166,19 @@ class OmniBiGroupDataDistManager(LLMDataDistManager): else: self._pull_blocks(prompt_cache_key, kv_cache, group_src_blocks, group_tgt_blocks) - else: - if len(group_tgt_blocks) == 0: - continue - tmp_src, tmp_tgt = group_src_blocks, group_tgt_blocks - if len(group_src_blocks) < omni_max_blocks: - tmp_tgt = group_tgt_blocks[:len(group_src_blocks)] - elif len(group_src_blocks) > omni_max_blocks: - tmp_src = group_src_blocks[:sink] + group_src_blocks[-recent:] - if len(tmp_src) != len(tmp_tgt): - raise RuntimeError("src and tgt cannot match for omni kv caches. " - f"{src_blocks=}, {tgt_blocks=}, " - f"{tmp_src=}, {tmp_tgt=}, " - f"{len(tmp_src)=}, {len(tmp_tgt)=}.") + else: + if len(group_tgt_blocks) == 0: + continue + tmp_src, tmp_tgt = group_src_blocks, group_tgt_blocks + if len(group_src_blocks) < omni_max_blocks: + tmp_tgt = group_tgt_blocks[:len(group_src_blocks)] + elif len(group_src_blocks) > omni_max_blocks: + tmp_src = group_src_blocks[:sink] + group_src_blocks[-recent:] + if len(tmp_src) != len(tmp_tgt): + raise RuntimeError("src and tgt cannot match for omni kv caches. 
" + f"{src_blocks=}, {tgt_blocks=}, " + f"{tmp_src=}, {tmp_tgt=}, " + f"{len(tmp_src)=}, {len(tmp_tgt)=}.") if USE_NEW_DATADIST: ret = self._pull_blocks(prompt_cache_key, kv_cache, tmp_src, tmp_tgt) -- Gitee From 307d1ddf85bacab3d7d148244cbcad29d1f80eb1 Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 2 Dec 2025 19:07:41 +0800 Subject: [PATCH 12/15] resolve the conflicts in __init__.py in pd --- omni/accelerators/pd/__init__.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/omni/accelerators/pd/__init__.py b/omni/accelerators/pd/__init__.py index a25d3169a..11168b5b7 100644 --- a/omni/accelerators/pd/__init__.py +++ b/omni/accelerators/pd/__init__.py @@ -10,8 +10,19 @@ from vllm.distributed.kv_transfer.kv_connector.factory import \ def register(): KVConnectorFactory.register_connector( "AscendHcclConnectorV1", - "omni.accelerators.pd.omni_cache_connector_v1" if os.getenv("ENABLE_OMNI_CACHE", "0") == "1" - else "omni.accelerators.pd.llmdatadist_connector_v1", + ( + "omni.accelerators.pd.omni_cache_connector_d2p" + if os.getenv("ENABLE_D_SIDE_FIRST", "0") == "1" and os.getenv("ENABLE_OMNI_CACHE", "0") == "1" + else ( + "omni.accelerators.pd.llmdatadist_connector_d2p" + if os.getenv("ENABLE_D_SIDE_FIRST", "0") == "1" + else ( + "omni.accelerators.pd.omni_cache_connector_v1" + if os.getenv("ENABLE_OMNI_CACHE", "0") == "1" + else "omni.accelerators.pd.llmdatadist_connector_v1" + ) + ) + ), "LLMDataDistConnector" ) @@ -26,7 +37,7 @@ def register(): "omni.accelerators.pd.ems_connector", "EmsConnector" ) - + KVConnectorFactory.register_connector( "SwapKVConnector", "omni.accelerators.pd.swap_kv_connector", -- Gitee From 2c6a627acdbe29f7c1edf28d41fa6dbf0bd5c3d4 Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 2 Dec 2025 19:33:08 +0800 Subject: [PATCH 13/15] add pp support to datadist_manager_v1 --- .../accelerators/pd/llmdatadist_manager_v1.py | 122 ++++++++++++++---- 1 file changed, 100 insertions(+), 22 deletions(-) diff --git a/omni/accelerators/pd/llmdatadist_manager_v1.py b/omni/accelerators/pd/llmdatadist_manager_v1.py index a987dc1a3..2db568639 100644 --- a/omni/accelerators/pd/llmdatadist_manager_v1.py +++ b/omni/accelerators/pd/llmdatadist_manager_v1.py @@ -5,6 +5,7 @@ import json import time from collections import defaultdict, namedtuple from functools import cached_property +from typing import Optional import llm_datadist import torch @@ -64,6 +65,28 @@ RETRYABLE_CODES = [ NUM_DIE_PER_MACH = int(os.getenv("NUM_DIE_PER_MACH", "16")) +def get_kv_producer_pp_partitions(num_hidden_layers: int, pp_size: int, num_mtp_layers: int = 0, kv_producer_pp_partitions_str: Optional[str] = None) -> list[int]: + if kv_producer_pp_partitions_str is not None and kv_producer_pp_partitions_str != "null": + try: + partitions = [ + int(layer) for layer in kv_producer_pp_partitions_str.split(",") + ] + except ValueError as err: + raise ValueError("Invalid partition string: {}".format(kv_producer_pp_partitions_str)) from err + if len(partitions) != pp_size: + raise ValueError(f"len(partitions)={len(partitions)} does not match pp_size={pp_size}") + if sum(partitions) != num_hidden_layers: + raise ValueError(f"sum(partitions)={sum(partitions)} does not match num_hidden_layers={num_hidden_layers}") + else: + layers_per_partition = num_hidden_layers // pp_size + partitions = [layers_per_partition for _ in range(pp_size)] + remaining_layers = num_hidden_layers % pp_size + if remaining_layers: + for i in range(2, remaining_layers + 2): + partitions[-i] += 1 + 
partitions[-1] += num_mtp_layers + return partitions + class LLMDataDistConfig: """ Configuration for the separate deployment. @@ -96,6 +119,13 @@ class LLMDataDistConfig: # will be used in d side to checkout which P rank is selected to build kv link self.kv_parallel_size = self.kv_transfer_config.kv_parallel_size self.kv_producer_dp_size = self.kv_transfer_config.kv_connector_extra_config.get("kv_producer_dp_size", 1) + self.kv_producer_pp_size = self.kv_transfer_config.kv_connector_extra_config.get("kv_producer_pp_size", 1) + hf_config = vllm_config.model_config.hf_config + num_mtp_layers = getattr(hf_config, 'num_nextn_predict_layers', getattr(hf_config, 'num_mtp_layers', getattr(hf_config, 'n_predict', 0))) + self.kv_producer_pp_partitions = get_kv_producer_pp_partitions(hf_config.num_hidden_layers, + self.kv_producer_pp_size, + num_mtp_layers, + self.kv_transfer_config.kv_connector_extra_config.get("kv_producer_pp_partitions", None)) host_ip_list = self._get_worker_ips() self.host_ip_list = host_ip_list @@ -326,29 +356,44 @@ class LLMDataDistManager: def pull_kv(self, src_blocks, tgt_blocks, prompt_cluster_id, prefill_dp_rank): """ pull kv from remote cache to local cache, support to refresh link when pull kv fails """ - torch.npu.set_device(f"npu:{self.local_rank}") - for model_id, kv_cache in enumerate(self.registered_kv_caches): - prompt_cache_key = BlocksCacheKey( - prompt_cluster_id=prompt_cluster_id, model_id=model_id) - ret = self._pull_blocks(prompt_cache_key, kv_cache, + if os.getenv("ENABLE_PD_MOCKUP", "0") == "1": + return + torch.npu.set_device(f"npu:{self.local_rank}") + if self.data_dist_config.kv_producer_pp_size > 1: + for pp_stage_ind, cur_pp_stage_kv_caches in enumerate(self.registered_kv_caches): + for model_id, kv_cache in enumerate(cur_pp_stage_kv_caches): + cluster_id_pp_offset = pp_stage_ind * self.prefill_tp_dp_size + prompt_cache_key = BlocksCacheKey( + prompt_cluster_id=prompt_cluster_id + cluster_id_pp_offset, model_id=model_id + ) + ret = self._pull_blocks(prompt_cache_key, kv_cache, + src_blocks, tgt_blocks) + if not ret: + self._refresh_link(prompt_cluster_id, prefill_dp_rank, self.rank) + ret_updated = self._pull_blocks(prompt_cache_key, kv_cache, + src_blocks, tgt_blocks) + if not ret_updated: + raise RuntimeError(f"Failed to pull kv even if rebuild the kv link!") + else: + for model_id, kv_cache in enumerate(self.registered_kv_caches): + prompt_cache_key = BlocksCacheKey( + prompt_cluster_id=prompt_cluster_id, model_id=model_id) + ret = self._pull_blocks(prompt_cache_key, kv_cache, src_blocks, tgt_blocks) - if not ret: - logger.warning(f"======= failed pull kv with {prompt_cluster_id=} ========") - self._refresh_link(prompt_cluster_id, prefill_dp_rank, self.rank) - logger.warning(f"======= successfully rebuild kv link with {prompt_cluster_id=} ========") - ret_updated = self._pull_blocks(prompt_cache_key, kv_cache, - src_blocks, tgt_blocks) - if not ret_updated: - raise RuntimeError(f"Failed to pull kv even if rebuild the kv link!") + if not ret: + self._refresh_link(prompt_cluster_id, prefill_dp_rank, self.rank) + + ret_updated = self._pull_blocks(prompt_cache_key, kv_cache, + src_blocks, tgt_blocks) + if not ret_updated: + raise RuntimeError(f"Failed to pull kv even if rebuild the kv link!") def _refresh_link(self, prompt_cluster_id, prefill_dp_rank, d_rank): """ refresh the kv link: unlink + link """ - logger.warning(f"======= refresh_link with {prompt_cluster_id=} ========") (host_cluster_id, prefill_dp_rank, d_rank) = \ 
self._get_host_cluster_id(prompt_cluster_id, prefill_dp_rank, d_rank) if host_cluster_id is not None: self.close_link(host_cluster_id, prefill_dp_rank, d_rank) - logger.warning(f"======= rebuild_link with {prompt_cluster_id=} ========") self.register_link(host_cluster_id, prefill_dp_rank, d_rank) else: raise RuntimeError(f"Unregistered host cluster id!!!") @@ -416,22 +461,55 @@ class LLMDataDistManager: # spec model. flatten_kv_caches = maybe_split_kv_caches_for_spec_layers(flatten_kv_caches) + if self.data_dist_config.kv_producer_pp_size > 1: + if self.data_dist_config.is_prefill: + self._register_caches_prefill(flatten_kv_caches) + else: + self._register_caches_decode(flatten_kv_caches) + else: + for model_id, sub_kv_caches in enumerate(flatten_kv_caches): + cache_desc = CacheDesc(num_tensors=len(sub_kv_caches), shape=tuple(sub_kv_caches[0].shape), + data_type=TORCH_DTYPE_TO_NPU_DTYPE[sub_kv_caches[0].dtype]) + + cache_addrs = [int(item.data_ptr()) for item in sub_kv_caches] + + if self.data_dist_config.is_prefill: + cache_key = BlocksCacheKey(self.data_dist_engine.cluster_id, model_id=model_id) + else: + cache_key = None + + cache = self.data_dist_engine.cache_manager.register_blocks_cache(cache_desc, cache_addrs, cache_key) + self.registered_kv_caches.append(cache) + + logger.info(f" ***** registered_kv_caches num:{len(self.registered_kv_caches)}") + + def _register_caches_prefill(self, flatten_kv_caches): for model_id, sub_kv_caches in enumerate(flatten_kv_caches): cache_desc = CacheDesc(num_tensors=len(sub_kv_caches), shape=tuple(sub_kv_caches[0].shape), - data_type=TORCH_DTYPE_TO_NPU_DTYPE[sub_kv_caches[0].dtype]) + data_type=TORCH_DTYPE_TO_NPU_DTYPE[sub_kv_caches[0].dtype]) cache_addrs = [int(item.data_ptr()) for item in sub_kv_caches] - if self.data_dist_config.is_prefill: - cache_key = BlocksCacheKey(self.data_dist_engine.cluster_id, model_id=model_id) - else: - cache_key = None + cache_key = BlocksCacheKey(self.data_dist_engine.cluster_id, model_id=model_id) cache = self.data_dist_engine.cache_manager.register_blocks_cache(cache_desc, cache_addrs, cache_key) self.registered_kv_caches.append(cache) - logger.debug(f" ***** registered_kv_caches num:{len(self.registered_kv_caches)}") -# reuse the existing code + + def _register_caches_decode(self, flatten_kv_caches): + prefill_pp_partitions = self.data_dist_config.kv_producer_pp_partitions + cnt_layer_num = 0 + for cur_pp_stage_layer_num in prefill_pp_partitions: + cur_pp_stage_kv_caches = [] + for origin_sub_kv_caches in flatten_kv_caches: + sub_kv_caches = origin_sub_kv_caches[cnt_layer_num : cnt_layer_num + cur_pp_stage_layer_num] + cache_desc = CacheDesc(num_tensors=len(sub_kv_caches), shape=tuple(sub_kv_caches[0].shape), data_type=TORCH_DTYPE_TO_NPU_DTYPE[sub_kv_caches[0].dtype]) + cache_addrs = [int(item.data_ptr()) for item in sub_kv_caches] + cache = self.data_dist_engine.cache_manager.register_blocks_cache(cache_desc, cache_addrs, None) + cur_pp_stage_kv_caches.append(cache) + self.registered_kv_caches.append(cur_pp_stage_kv_caches) + cnt_layer_num += cur_pp_stage_layer_num + def unzip_kv_cache_dict(kv_caches: dict[str, torch.Tensor], ): # Convert kv_caches dict to a list of tensors in the order of layer_index. _, first_kv_cache = next(iter(kv_caches.items())) -- Gitee From b2b633338451b8dc42a2407e3a1f876d047a1485 Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Tue, 2 Dec 2025 14:20:53 +0000 Subject: [PATCH 14/15] update omni/accelerators/pd/llmdatadist_connector_v1.py. 
fix one typo Signed-off-by: Yao Yunxiang --- omni/accelerators/pd/llmdatadist_connector_v1.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/omni/accelerators/pd/llmdatadist_connector_v1.py b/omni/accelerators/pd/llmdatadist_connector_v1.py index 1781c7b32..462ff7e51 100644 --- a/omni/accelerators/pd/llmdatadist_connector_v1.py +++ b/omni/accelerators/pd/llmdatadist_connector_v1.py @@ -125,7 +125,7 @@ class LLMDataDistConnector(KVConnectorBase_V1): vllm_config.kv_transfer_config.kv_parallel_size = 1 logger.info("Set kv_parallel_size to 1 when use deepseek mla model.") - if ENABLE_DYNAMIC_LLMDATADIST: + if FLAG_ENABLE_DYNAMIC_LLMDATADIST: local_host_ip = get_local_ip() local_host_port = LLMDATADIST_BASE_PORT self.datadist_config = LLMDataDistConfig(vllm_config, local_host_ip, local_host_port, ignore_load_rank=True) @@ -356,7 +356,7 @@ class PrefillConnectorWorker: if use_omni_attn_mgr: manager_cls = OmniBiGroupDataDistManager logger.warning(f"PrefillingConnector is using Omni datadist manager for KV transfer.") - if ENABLE_DYNAMIC_LLMDATADIST: + if FLAG_ENABLE_DYNAMIC_LLMDATADIST: self.datadist_manager = manager_cls(vllm_config, self.host_ip, LLMDATADIST_BASE_PORT) else: self.datadist_manager = manager_cls(vllm_config) @@ -366,7 +366,7 @@ class PrefillConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.datadist_manager.register_memory(kv_caches) - if not FLAG_ENABLE_DYNAMIC_LLMDATADIST + if not FLAG_ENABLE_DYNAMIC_LLMDATADIST: self.datadist_manager.register_link() def start_load_kv(self, metadata: DatadistConnectorMetadataPrefill): -- Gitee From a8660773685d91d36a657d9ebb6015c0e1b64601 Mon Sep 17 00:00:00 2001 From: Yao Yunxiang Date: Wed, 3 Dec 2025 14:33:17 +0800 Subject: [PATCH 15/15] revise _pull_kv and sleep to be async func, to avoid thread blocking --- .../accelerators/pd/llmdatadist_manager_v1.py | 112 +++++++++++++----- 1 file changed, 84 insertions(+), 28 deletions(-) diff --git a/omni/accelerators/pd/llmdatadist_manager_v1.py b/omni/accelerators/pd/llmdatadist_manager_v1.py index 2db568639..f11505cdf 100644 --- a/omni/accelerators/pd/llmdatadist_manager_v1.py +++ b/omni/accelerators/pd/llmdatadist_manager_v1.py @@ -27,6 +27,9 @@ from vllm.distributed.parallel_state import (get_tp_group, get_dp_group, get_wor import os import socket import struct +import asyncio +import inspect +from concurrent.futures import TimeoutError as FutureTimeoutError logger = init_logger(__name__) @@ -323,42 +326,55 @@ class LLMDataDistManager: self.registered_link_infos.pop((host_cluster_id, prefill_dp_rank, d_rank), None) logger.info(f"rank:{self.rank} unlinked with : {remote_host_ip}, {prompt_cluster_id_list=}") - def _pull_blocks(self, src_cache_key, dst_cache, src_blocks, dst_blocks): - """" pull kv from remote cache to local cache, support return error state if pull kv fails """ + async def _pull_blocks(self, src_cache_key, dst_cache, src_blocks, dst_blocks): + """Pull kv from remote cache to local cache; return False on failure.""" + pull_sync = self.data_dist_engine.cache_manager.pull_blocks + for attempt in range(KV_CACHE_RETRY_TIMES): try: - self.data_dist_engine.cache_manager.pull_blocks( - src_cache_key, dst_cache, src_blocks, dst_blocks - ) + if hasattr(asyncio, 'to_thread'): + await asyncio.to_thread( + pull_sync, src_cache_key, dst_cache, src_blocks, dst_blocks + ) + else: + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, pull_sync, src_cache_key, dst_cache, src_blocks, dst_blocks + ) return True + + except 
asyncio.CancelledError: + raise RuntimeError("Pull blocks operation cancelled") + except LLMException as e: - code = e.status_code + code = getattr(e, "status_code", None) if code in RETRYABLE_CODES: logger.info( - f"kv cache pull blocks failed, need retry" - f"(attempt {attempt + 1}/{KV_CACHE_RETRY_TIMES}): {e}" + "kv cache pull blocks failed, retry (%d/%d): %s", + attempt + 1, KV_CACHE_RETRY_TIMES, e ) if attempt < KV_CACHE_RETRY_TIMES - 1: - time.sleep(KV_CACHE_RETRY_WAIT_SECOND) + await asyncio.sleep(KV_CACHE_RETRY_WAIT_SECOND) continue - logger.error( - f"kv cache pull blocks failed after {KV_CACHE_RETRY_TIMES} attempts: {e}" - ) + logger.error("kv cache pull blocks failed after %d attempts: %s", + KV_CACHE_RETRY_TIMES, e) return False else: - logger.error(f"kv cache pull blocks failed (non-retryable): {e}") + logger.error("kv cache pull blocks failed (non-retryable): %s", e) return False + except (TypeError, ValueError) as e: - logger.error(f"kv cache pull blocks input error: {e}") + logger.error("kv cache pull blocks input error: %s", e) return False - logger.error("kv cache pull blocks exhausted attempts without success") + + logger.error("kv cache pull blocks exhausted attempts") return False def pull_kv(self, src_blocks, tgt_blocks, prompt_cluster_id, prefill_dp_rank): """ pull kv from remote cache to local cache, support to refresh link when pull kv fails """ if os.getenv("ENABLE_PD_MOCKUP", "0") == "1": return - torch.npu.set_device(f"npu:{self.local_rank}") + torch.npu.set_device(f"npu:{self.local_rank}") if self.data_dist_config.kv_producer_pp_size > 1: for pp_stage_ind, cur_pp_stage_kv_caches in enumerate(self.registered_kv_caches): for model_id, kv_cache in enumerate(cur_pp_stage_kv_caches): @@ -366,27 +382,23 @@ class LLMDataDistManager: prompt_cache_key = BlocksCacheKey( prompt_cluster_id=prompt_cluster_id + cluster_id_pp_offset, model_id=model_id ) - ret = self._pull_blocks(prompt_cache_key, kv_cache, - src_blocks, tgt_blocks) + ret = self._run_coro_sync(lambda: self._pull_blocks(prompt_cache_key, kv_cache, src_blocks, tgt_blocks)) if not ret: self._refresh_link(prompt_cluster_id, prefill_dp_rank, self.rank) - ret_updated = self._pull_blocks(prompt_cache_key, kv_cache, - src_blocks, tgt_blocks) + + ret_updated = self._run_coro_sync(lambda: self._pull_blocks(prompt_cache_key, kv_cache, src_blocks, tgt_blocks)) if not ret_updated: - raise RuntimeError(f"Failed to pull kv even if rebuild the kv link!") + raise RuntimeError("Failed to pull kv even if rebuild the kv link!") else: for model_id, kv_cache in enumerate(self.registered_kv_caches): - prompt_cache_key = BlocksCacheKey( - prompt_cluster_id=prompt_cluster_id, model_id=model_id) - ret = self._pull_blocks(prompt_cache_key, kv_cache, - src_blocks, tgt_blocks) + prompt_cache_key = BlocksCacheKey(prompt_cluster_id=prompt_cluster_id, model_id=model_id) + ret = self._run_coro_sync(lambda: self._pull_blocks(prompt_cache_key, kv_cache, src_blocks, tgt_blocks)) if not ret: self._refresh_link(prompt_cluster_id, prefill_dp_rank, self.rank) - ret_updated = self._pull_blocks(prompt_cache_key, kv_cache, - src_blocks, tgt_blocks) + ret_updated = self._run_coro_sync(lambda: self._pull_blocks(prompt_cache_key, kv_cache, src_blocks, tgt_blocks)) if not ret_updated: - raise RuntimeError(f"Failed to pull kv even if rebuild the kv link!") + raise RuntimeError("Failed to pull kv even if rebuild the kv link!") def _refresh_link(self, prompt_cluster_id, prefill_dp_rank, d_rank): """ refresh the kv link: unlink + link """ @@ -510,6 +522,50 @@ 
class LLMDataDistManager: self.registered_kv_caches.append(cur_pp_stage_kv_caches) cnt_layer_num += cur_pp_stage_layer_num + def _run_coro_sync(self, coro_or_callable, timeout: float | None = None): + # Normalize to coroutine or synchronous result + coro = None + # If caller passed a coroutine object already + if inspect.iscoroutine(coro_or_callable): + coro = coro_or_callable + # If caller passed an "async def" function (coroutine function) + elif inspect.iscoroutinefunction(coro_or_callable): + coro = coro_or_callable() + # If it's any callable, call it and inspect the return value + elif callable(coro_or_callable): + result = coro_or_callable() + if inspect.iscoroutine(result): + coro = result + else: + # It's a synchronous result (not a coroutine) — return it directly. + return result + else: + raise ValueError(f"Expected coroutine or callable, got {type(coro_or_callable)}") + + # At this point `coro` is a coroutine object (awaitable). + # Decide how to run it from sync context. + try: + running_loop = asyncio.get_running_loop() + except RuntimeError: + running_loop = None + + if running_loop is None: + # No loop in current thread: safe to create a temporary loop and run it. + return asyncio.run(coro) + else: + # There is a running loop in current thread -> cannot use asyncio.run. + # need a separate loop (e.g. self._main_loop) running in another thread. + main_loop = getattr(self, "_main_loop", None) + if main_loop is None or not getattr(main_loop, "is_running", lambda: False)(): + raise RuntimeError("Detected a running asyncio event loop in current thread.") + + future = asyncio.run_coroutine_threadsafe(coro, main_loop) + try: + return future.result(timeout=timeout) + except FutureTimeoutError: + future.cancel() + raise + def unzip_kv_cache_dict(kv_caches: dict[str, torch.Tensor], ): # Convert kv_caches dict to a list of tensors in the order of layer_index. _, first_kv_cache = next(iter(kv_caches.items())) -- Gitee
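Illustrative sketches follow; none of the code below is part of the patches above. First, a minimal sketch of the per-path socket reuse pattern behind _send_pulled_kv_req_list: one zmq.PUSH socket is created lazily for each target address, cached, and reused for later notifications, and the payload is sent as a JSON string. The class name and the address in the usage comment are made up for illustration; only the PUSH/connect/send_string flow is taken from the patch.

    import json
    import zmq

    class KvDoneNotifier:
        """Caches one PUSH socket per destination and sends JSON payloads."""

        def __init__(self) -> None:
            self.ctx = zmq.Context()
            self.sockets: dict[str, zmq.Socket] = {}

        def send(self, path: str, data) -> None:
            sock = self.sockets.get(path)
            if sock is None:
                # First message to this destination: create and cache the socket.
                sock = self.ctx.socket(zmq.PUSH)
                sock.connect(path)
                self.sockets[path] = sock
            sock.send_string(json.dumps(data))

    # Usage (placeholder address):
    #     notifier = KvDoneNotifier()
    #     notifier.send("tcp://10.0.0.1:5555",
    #                   [{"remote_request_id": "req-1", "trace_headers": {}}])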
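The env-flag combinations handled by the nested conditional expression in register() (patch 12) are easier to audit when written as a flat decision helper. This is only a readability rewrite of the same mapping; the module paths are the ones that appear in the patch, and ENABLE_D_SIDE_FIRST / ENABLE_OMNI_CACHE are read exactly as there.

    import os

    def pick_connector_module() -> str:
        """Return the connector module chosen by the env flags in register()."""
        d_side_first = os.getenv("ENABLE_D_SIDE_FIRST", "0") == "1"
        omni_cache = os.getenv("ENABLE_OMNI_CACHE", "0") == "1"
        if d_side_first and omni_cache:
            return "omni.accelerators.pd.omni_cache_connector_d2p"
        if d_side_first:
            return "omni.accelerators.pd.llmdatadist_connector_d2p"
        if omni_cache:
            return "omni.accelerators.pd.omni_cache_connector_v1"
        return "omni.accelerators.pd.llmdatadist_connector_v1"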
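A self-contained sketch of the fallback branch of get_kv_producer_pp_partitions (patch 13), useful for checking how layers are spread across prefill PP stages when no explicit partition string is configured: the remainder goes to the later stages while the very last stage is skipped, and the MTP draft layers are appended to the last stage. The concrete numbers in the example are assumptions, not values from the patch.

    def split_layers(num_layers: int, pp_size: int, num_mtp_layers: int = 0) -> list[int]:
        base = num_layers // pp_size
        partitions = [base] * pp_size
        remaining = num_layers % pp_size
        # Mirrors `for i in range(2, remaining + 2): partitions[-i] += 1`:
        # extra layers go to the stages just before the last one.
        for i in range(2, remaining + 2):
            partitions[-i] += 1
        # The last stage additionally hosts the MTP (draft) layers.
        partitions[-1] += num_mtp_layers
        return partitions

    # Example: 61 hidden layers, 4 prefill PP stages, 1 MTP layer
    #     split_layers(61, 4, 1) -> [15, 15, 16, 16]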
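The decode-side pull loops address each prefill PP stage through a cluster-id offset, prompt_cluster_id + pp_stage_ind * prefill_tp_dp_size, and each KV group or tensor through model_id in the BlocksCacheKey. A small sketch of that addressing, with illustrative numbers, shows which (cluster_id, model_id) pairs a single request touches; the function name is made up.

    def cache_keys_for_request(prompt_cluster_id: int,
                               prefill_tp_dp_size: int,
                               num_pp_stages: int,
                               num_model_ids: int) -> list[tuple[int, int]]:
        """Enumerate the (prompt_cluster_id, model_id) pairs pulled for one request."""
        keys = []
        for pp_stage in range(num_pp_stages):
            offset = pp_stage * prefill_tp_dp_size  # cluster_id_pp_offset in pull_kv
            for model_id in range(num_model_ids):
                keys.append((prompt_cluster_id + offset, model_id))
        return keys

    # Example: base cluster id 8, 16 prefill ranks per stage, 2 PP stages, K and V:
    #     cache_keys_for_request(8, 16, 2, 2) -> [(8, 0), (8, 1), (24, 0), (24, 1)]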
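The recovery path in pull_kv is the same in every branch: one failed pull triggers _refresh_link (unlink plus link against the same prefill peer) followed by exactly one retry before raising. A compact sketch of that control flow, with the two operations passed in as placeholders since the real calls need a cache key, a registered cache, and block lists.

    from typing import Callable

    def pull_with_relink(pull_once: Callable[[], bool],
                         refresh_link: Callable[[], None]) -> None:
        """Pull once; on failure rebuild the KV link and retry a single time."""
        if pull_once():
            return
        refresh_link()
        if not pull_once():
            raise RuntimeError("Failed to pull kv even after rebuilding the kv link!")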
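For the omni-attention KV group, pull_kv only ever moves at most SINK + RECENT blocks: when the prefill (src) block list is longer, only the first SINK and the last RECENT blocks are pulled, and when it is shorter the decode (tgt) block list is truncated to the same length. A standalone sketch of that reconciliation; the SINK and RECENT values here are illustrative, the real ones come from kv_cache_interface.

    SINK, RECENT = 1, 3  # assumed values for the example below

    def match_omni_blocks(src: list[int], tgt: list[int]) -> tuple[list[int], list[int]]:
        """Trim src/tgt block lists so the omni-attention pull has equal lengths."""
        max_blocks = SINK + RECENT
        if len(src) < max_blocks:
            tgt = tgt[:len(src)]
        elif len(src) > max_blocks:
            src = src[:SINK] + src[-RECENT:]
        if len(src) != len(tgt):
            raise RuntimeError(f"src and tgt cannot match: {len(src)=}, {len(tgt)=}")
        return src, tgt

    # Example: 6 prefill blocks, 4 decode slots
    #     match_omni_blocks([10, 11, 12, 13, 14, 15], [0, 1, 2, 3])
    #     -> ([10, 13, 14, 15], [0, 1, 2, 3])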
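Patch 15 turns the blocking pull into an awaitable by pushing the datadist call onto a worker thread with asyncio.to_thread and replacing time.sleep with asyncio.sleep between attempts, so only retryable error codes are retried. A generic sketch of that wrapper; the constants and the is_retryable predicate stand in for KV_CACHE_RETRY_TIMES, KV_CACHE_RETRY_WAIT_SECOND, and the LLMException status-code check.

    import asyncio

    RETRY_TIMES = 3
    RETRY_WAIT_SECONDS = 1.0

    async def pull_with_retry(blocking_pull, *args, is_retryable=lambda exc: False) -> bool:
        """Run a blocking pull off the event loop; retry only retryable failures."""
        for attempt in range(RETRY_TIMES):
            try:
                await asyncio.to_thread(blocking_pull, *args)
                return True
            except Exception as exc:  # the real code narrows this to LLMException
                if is_retryable(exc) and attempt < RETRY_TIMES - 1:
                    await asyncio.sleep(RETRY_WAIT_SECONDS)
                    continue
                return False
        return False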
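_run_coro_sync falls back to asyncio.run when the calling thread has no running event loop, and otherwise hands the coroutine to self._main_loop with run_coroutine_threadsafe. The patch does not show where that loop comes from, so the wiring below is an assumption: one long-lived loop started in a daemon thread, plus a small helper that blocks on the result.

    import asyncio
    import threading

    def start_background_loop() -> asyncio.AbstractEventLoop:
        """Start a dedicated event loop in a daemon thread and return it."""
        loop = asyncio.new_event_loop()

        def _run() -> None:
            asyncio.set_event_loop(loop)
            loop.run_forever()

        threading.Thread(target=_run, daemon=True, name="datadist-pull-loop").start()
        return loop

    def run_sync(loop: asyncio.AbstractEventLoop, coro, timeout: float | None = None):
        """Submit a coroutine to the background loop and wait for its result."""
        future = asyncio.run_coroutine_threadsafe(coro, loop)
        return future.result(timeout=timeout)

    # Hypothetical wiring in the manager's __init__:
    #     self._main_loop = start_background_loop()
    # after which the synchronous pull path can do:
    #     ok = run_sync(self._main_loop, self._pull_blocks(key, cache, src, dst))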