openEuler/yuanrong-datasystem

0001-implement-yr-datasystem-connector-and-support-multimoda.patch 70.29 KB
yaohaolin committed on 2025-10-24 21:26 +08:00 · sync codes
From ea9b3d00989ebff4ae54a47212c996edb08fe5b0 Mon Sep 17 00:00:00 2001
From: yangsonglin <yangsonglin0821@163.com>
Date: Mon, 1 Sep 2025 07:37:42 +0000
Subject: [PATCH] implement yr-datasystem connector and support multimodal
---
.github/workflows/vllm_ascend_test_pd.yaml | 5 +-
tests/e2e/pd_disaggreate/yuanrong/README.md | 155 ++++
.../pd_disaggreate/yuanrong/clean_yuanrong.sh | 9 +
.../yuanrong/run_pd_instances.sh | 44 +
.../yuanrong/run_proxy_server.sh | 14 +
.../pd_disaggreate/yuanrong/run_yuanrong.sh | 24 +
.../yuanrong/simple_pd_proxy_server.py | 212 +++++
.../yuanrong/test_yuanrong_connector.py | 141 +++
vllm_ascend/attention/attention_v1.py | 31 +
vllm_ascend/core/schedule_config.py | 6 +-
vllm_ascend/core/scheduler.py | 92 +-
vllm_ascend/distributed/__init__.py | 4 +
vllm_ascend/distributed/yuanrong_connector.py | 828 ++++++++++++++++++
13 files changed, 1558 insertions(+), 7 deletions(-)
create mode 100644 tests/e2e/pd_disaggreate/yuanrong/README.md
create mode 100644 tests/e2e/pd_disaggreate/yuanrong/clean_yuanrong.sh
create mode 100644 tests/e2e/pd_disaggreate/yuanrong/run_pd_instances.sh
create mode 100644 tests/e2e/pd_disaggreate/yuanrong/run_proxy_server.sh
create mode 100644 tests/e2e/pd_disaggreate/yuanrong/run_yuanrong.sh
create mode 100644 tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py
create mode 100644 tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py
create mode 100644 vllm_ascend/distributed/yuanrong_connector.py
diff --git a/.github/workflows/vllm_ascend_test_pd.yaml b/.github/workflows/vllm_ascend_test_pd.yaml
index a86ba60..cd13509 100644
--- a/.github/workflows/vllm_ascend_test_pd.yaml
+++ b/.github/workflows/vllm_ascend_test_pd.yaml
@@ -108,4 +108,7 @@ jobs:
- name: Run vllm-project/vllm-ascend PD Disaggregation edge test
run: |
- bash tests/e2e/pd_disaggreate/run_edge_case_test.sh
\ No newline at end of file
+ bash tests/e2e/pd_disaggreate/run_edge_case_test.sh
+ - name: Run vllm-project/vllm-ascend PD Disaggregation test with YuanRong Connector
+ run: |
+ pytest -sv tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py
\ No newline at end of file
diff --git a/tests/e2e/pd_disaggreate/yuanrong/README.md b/tests/e2e/pd_disaggreate/yuanrong/README.md
new file mode 100644
index 0000000..427919c
--- /dev/null
+++ b/tests/e2e/pd_disaggreate/yuanrong/README.md
@@ -0,0 +1,155 @@
+# Overview: transfer KVCache through host memory with YuanRong connector
+
+### Dataflow
+
+
+|----------------------- prefill node ---------------------| &nbsp; |---------------------- decode node ---------------------|
+
+Prefill Instance -----> YuanRongConnector -----> YuanRong Data Worker -----> YuanRongConnector -----> Decode Instance
+
+|----- kv on npu -----| &nbsp; |----- kv offload to host -----| &nbsp; |----- kv transfer by host net -----| &nbsp; |----- kv load to npu -----|
+
+### Pros
+- Network jitter and failures are handled outside the vLLM process, giving better isolation and fault tolerance
+- No communication buffers need to be allocated on the NPU, enabling larger batches and higher throughput
+- Works seamlessly with features that require offloading the KV cache to host memory or SSD, such as prefix caching, priority scheduling, RAG, etc.
+### Cons
+- Higher transfer latency than device-to-device transfer, so it is not optimal for latency-sensitive scenarios
+
+
+
+
+
+# Installation
+
+## Install etcd
+#### 1. Download the latest binaries from [etcd github releases](https://github.com/etcd-io/etcd/releases)
+```
+ETCD_VERSION="v3.5.12"
+wget https://github.com/etcd-io/etcd/releases/download/${ETCD_VERSION}/etcd-${ETCD_VERSION}-linux-amd64.tar.gz
+```
+#### 2. Extract and install
+```
+tar -xvf etcd-${ETCD_VERSION}-linux-amd64.tar.gz
+cd etcd-${ETCD_VERSION}-linux-amd64
+# copy the binary to system
+sudo cp etcd etcdctl /usr/local/bin/
+```
+#### 3. Verify installation
+```
+etcd --version
+etcdctl version
+```
+
+
+## Install YR-DataSystem
+#### Install from pip (recommended):
+
+```
+pip install yr-datasystem
+```
+
+#### Or install from source:
+
+- Refer to the yr-datasystem documentation [here](https://gitee.com/openeuler/yuanrong-datasystem)
+
+
+
+# Deployment
+## Deploy etcd
+> Note: this is a minimal example for deploying etcd; more options can be found at the [etcd official site](https://etcd.io/docs/current/op-guide/clustering/).
+
+#### Deploy a single-node etcd cluster on port 2379:
+```
+etcd \
+ --name etcd-single \
+ --data-dir /tmp/etcd-data \
+ --listen-client-urls http://0.0.0.0:2379 \
+ --advertise-client-urls http://0.0.0.0:2379 \
+ --listen-peer-urls http://0.0.0.0:2380 \
+ --initial-advertise-peer-urls http://0.0.0.0:2380 \
+ --initial-cluster etcd-single=http://0.0.0.0:2380
+```
+
+
+#### Parameters:
+- `--name`: human-readable name for this etcd member
+- `--data-dir`: directory to store data
+- `--listen-client-urls`: addresses to listen on for client traffic (0.0.0.0 allows access from any IP address)
+- `--advertise-client-urls`: client addresses advertised to clients
+- `--listen-peer-urls`: addresses to listen on for traffic from other nodes in the cluster
+- `--initial-advertise-peer-urls`: peer addresses advertised to other nodes in the cluster
+- `--initial-cluster`: initial nodes in the cluster (format: name1=peer_url1,name2=peer_url2,...)
+
+#### Try to access the etcd cluster with the `etcdctl` command:
+```
+etcdctl --endpoints "127.0.0.1:2379" put key "value"
+etcdctl --endpoints "127.0.0.1:2379" get key
+```
+The etcd cluster is successfully deployed if these commands work.
+
+## Deploy YR-DataSystem
+#### Deploy a single-node yr-datasystem cluster with a minimal config:
+```
+dscli start -w --worker_address "127.0.0.1:31501" --etcd_address "127.0.0.1:2379"
+# [INFO] [ OK ] Start worker service @ 127.0.0.1:31501 success, PID: 38100
+```
+yr-datasystem is deployed successfully once you see the `[ OK ]` output.
+
+#### To safely stop and clean up the yr-datasystem processes, run:
+```
+dscli stop -w --worker_address "127.0.0.1:31501"
+```
+#### Please refer to the [yr-datasystem gitee repo](https://gitee.com/openeuler/yuanrong-datasystem) for more information.
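+
+#### (Optional) Verify the deployment from Python
+As a quick sanity check, you can connect to the deployed worker with the same client class that YuanRongConnector constructs internally. The snippet below is a minimal sketch, assuming yr-datasystem is installed and the worker from the example above is listening on 127.0.0.1:31501; device id 0 is used only for illustration:
+```python
+# Minimal connectivity check against a yr-datasystem worker (sketch).
+# DsTensorClient is the client class that YuanRongConnector uses internally.
+from datasystem import DsTensorClient
+
+ip, port, device_id = "127.0.0.1", 31501, 0  # worker address from `dscli start`, NPU device id
+client = DsTensorClient(ip, port, device_id)
+client.init()  # same construction/initialization sequence the connector performs
+print(f"DsTensorClient initialized against {ip}:{port}")
+```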
+
+# Run disaggregated prefill with vLLM v1
+
+> Note: an example script for 1P1D disaggregated prefill is available at *vllm-ascend/tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py*
+
+#### 1. Set the yr-datasystem worker address in an environment variable:
+
+`export DS_WORKER_ADDR=127.0.0.1:31501`
+
+YuanRongConnector reads the yr-datasystem address from this environment variable.
+
+#### 2. Start two vLLM instances with YuanRongConnector as the backend to form a 1P1D disaggregated cluster:
+```
+export VLLM_USE_V1=True
+
+# start a prefill instance on localhost:8100
+ASCEND_RT_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct \
+ --port 8100 \
+ --max-num-batched-tokens 45000 \
+ --gpu-memory-utilization 0.8 \
+ --trust-remote-code \
+ --enforce-eager \
+ --kv-transfer-config \
+ '{"kv_connector":"YuanRongConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
+
+# start a decode instance on localhost:8200
+ASCEND_RT_VISIBLE_DEVICES=1 vllm serve Qwen/Qwen2.5-7B-Instruct \
+ --port 8200 \
+ --max-num-batched-tokens 45000 \
+ --gpu-memory-utilization 0.8 \
+ --trust-remote-code \
+ --enforce-eager \
+ --kv-transfer-config \
+ '{"kv_connector":"YuanRongConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
+```
+
+#### 3. Start a proxy server to serve and route HTTP requests:
+```
+python vllm-ascend/tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py --prefiller-port 8100 --decoder-port 8200
+```
+
+#### 4. Send HTTP requests to the proxy server:
+```
+curl -X POST -s http://localhost:8000/v1/completions \
+-H "Content-Type: application/json" \
+-d '{
+"model": "Qwen/Qwen2.5-7B-Instruct",
+"prompt": "who is the presiden of the united states?",
+"max_tokens": 50,
+"temperature": 0
+}'
+```
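+
+The same request can also be sent from Python; the sketch below mirrors the helper used by the e2e test (it assumes the `requests` package is installed and the proxy from step 3 is listening on localhost:8000):
+```python
+# Send a completion request to the proxy server (sketch mirroring the e2e test helper).
+import requests
+
+proxy_url = "http://localhost:8000/v1/completions"
+data = {
+    "model": "Qwen/Qwen2.5-7B-Instruct",
+    "prompt": "who is the president of the united states?",
+    "max_tokens": 50,
+    "temperature": 0,
+}
+response = requests.post(proxy_url, json=data, timeout=10)
+response.raise_for_status()
+print(response.json()["choices"][0]["text"])
+```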
\ No newline at end of file
diff --git a/tests/e2e/pd_disaggreate/yuanrong/clean_yuanrong.sh b/tests/e2e/pd_disaggreate/yuanrong/clean_yuanrong.sh
new file mode 100644
index 0000000..7941a45
--- /dev/null
+++ b/tests/e2e/pd_disaggreate/yuanrong/clean_yuanrong.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+HOST_IP=$1
+WORKER_PORT=$2
+
+dscli stop \
+ --worker_address ${HOST_IP}:${WORKER_PORT}
+
+pkill etcd
\ No newline at end of file
diff --git a/tests/e2e/pd_disaggreate/yuanrong/run_pd_instances.sh b/tests/e2e/pd_disaggreate/yuanrong/run_pd_instances.sh
new file mode 100644
index 0000000..c2e18b5
--- /dev/null
+++ b/tests/e2e/pd_disaggreate/yuanrong/run_pd_instances.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+MODEL_NAME=$1
+HOST_IP=$2
+PREFILL_PORT=$3
+DECODE_PORT=$4
+
+if python -c "import datasystem" &> /dev/null; then
+ echo "yr-datasystem is already installed"
+else
+ echo "Install yr-datasystem ..."
+ python -m pip install yr-datasystem
+fi
+
+wait_for_server() {
+ local port=$1
+ timeout 1200 bash -c "
+ until curl -s ${HOST_IP}:${port}/v1/completions > /dev/null; do
+ sleep 1
+ done" && return 0 || return 1
+}
+
+ASCEND_RT_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
+ --host ${HOST_IP} \
+ --port ${PREFILL_PORT} \
+ --max-num-batched-tokens 45000 \
+ --gpu-memory-utilization 0.8 \
+ --trust-remote-code \
+ --enforce-eager \
+ --kv-transfer-config \
+ '{"kv_connector":"YuanRongConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
+
+ASCEND_RT_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
+ --host ${HOST_IP} \
+ --port ${DECODE_PORT} \
+ --max-num-batched-tokens 45000 \
+ --gpu-memory-utilization 0.8 \
+ --trust-remote-code \
+ --enforce-eager \
+ --kv-transfer-config \
+ '{"kv_connector":"YuanRongConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
+
+wait_for_server ${PREFILL_PORT}
+wait_for_server ${DECODE_PORT}
diff --git a/tests/e2e/pd_disaggreate/yuanrong/run_proxy_server.sh b/tests/e2e/pd_disaggreate/yuanrong/run_proxy_server.sh
new file mode 100644
index 0000000..879f486
--- /dev/null
+++ b/tests/e2e/pd_disaggreate/yuanrong/run_proxy_server.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+PROXY_SERVER_SCRIPT=$1
+HOST=$2
+PORT=$3
+PREFILL_PORT=$4
+DECODE_PORT=$5
+
+python ${PROXY_SERVER_SCRIPT} \
+ --host ${HOST} \
+ --port ${PORT} \
+ --prefiller-host ${HOST} \
+ --prefiller-port ${PREFILL_PORT} \
+ --decoder-host ${HOST} \
+ --decoder-port ${DECODE_PORT} &
\ No newline at end of file
diff --git a/tests/e2e/pd_disaggreate/yuanrong/run_yuanrong.sh b/tests/e2e/pd_disaggreate/yuanrong/run_yuanrong.sh
new file mode 100644
index 0000000..8f4e3d7
--- /dev/null
+++ b/tests/e2e/pd_disaggreate/yuanrong/run_yuanrong.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+HOST_IP=$1
+WORKER_PORT=$2
+ETCD_PORT=$3
+
+MASTER_PORT=`expr ${WORKER_PORT} + 1`
+ETCD_PEER_PORT=`expr ${ETCD_PORT} + 1`
+
+etcd \
+ --name etcd-yuanrong \
+ --data-dir /tmp/etcd-yuanrong \
+ --listen-client-urls http://${HOST_IP}:${ETCD_PORT} \
+ --advertise-client-urls http://${HOST_IP}:${ETCD_PORT} \
+ --listen-peer-urls http://${HOST_IP}:${ETCD_PEER_PORT} \
+ --initial-advertise-peer-urls http://${HOST_IP}:${ETCD_PEER_PORT} \
+ --initial-cluster etcd-yuanrong=http://${HOST_IP}:${ETCD_PEER_PORT} &
+
+
+dscli start \
+ -w \
+ --worker_address ${HOST_IP}:${WORKER_PORT} \
+ --master_address ${HOST_IP}:${MASTER_PORT} \
+ --etcd_address ${HOST_IP}:${ETCD_PORT} &
\ No newline at end of file
diff --git a/tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py b/tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py
new file mode 100644
index 0000000..c6b957c
--- /dev/null
+++ b/tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py
@@ -0,0 +1,212 @@
+import argparse
+import os
+import time
+from contextlib import asynccontextmanager
+from uuid import uuid4
+
+import httpx
+import numpy as np
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ """
+ Lifespan context manager to handle startup and shutdown events.
+ """
+ # Startup: Initialize clients
+ prefiller_base_url = (
+ f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1")
+ decoder_base_url = (
+ f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1")
+
+ app.state.prefill_client = httpx.AsyncClient(timeout=None,
+ base_url=prefiller_base_url)
+ app.state.decode_client = httpx.AsyncClient(timeout=None,
+ base_url=decoder_base_url)
+
+ yield
+
+ # Shutdown: Close clients
+ await app.state.prefill_client.aclose()
+ await app.state.decode_client.aclose()
+
+
+# Update FastAPI app initialization to use lifespan
+app = FastAPI(lifespan=lifespan)
+
+
+class StatsCalculator:
+
+ def __init__(self):
+ self._stats = []
+ self._last_log_time = time.time()
+
+ def add(self, value):
+ self._stats.append(value)
+ if time.time() - self._last_log_time > 5:
+ self._log_stats()
+ self._last_log_time = time.time()
+
+ def _log_stats(self):
+ # Print average, median, and 99th percentile
+ np_arr = np.array(self._stats)
+ output_str = (
+ f"\nNum requests: {len(self._stats)}" +
+ "\nPrefill node TTFT stats:" +
+ f"\n - Average (ms): {np.mean(np_arr)}" +
+ f"\n - Median (ms): {np.median(np_arr)}" +
+ f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n")
+ print(
+ "===============================",
+ output_str,
+ "===============================",
+ )
+
+
+stats_calculator = StatsCalculator()
+counter = 0
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--port", type=int, default=8000)
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--prefiller-host", type=str, default="localhost")
+ parser.add_argument("--prefiller-port", type=int, default=8100)
+ parser.add_argument("--decoder-host", type=str, default="localhost")
+ parser.add_argument("--decoder-port", type=int, default=8200)
+ args = parser.parse_args()
+ return args
+
+
+# Initialize variables to hold the persistent clients
+app.state.prefill_client = None
+app.state.decode_client = None
+
+
+async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
+ req_data: dict, request_id: str):
+ """
+ Send a request to a service using a persistent client.
+ """
+ req_data = req_data.copy()
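+ # Prefill-only call: limit generation to a single token so this request just populates the KV cache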
+ req_data["max_tokens"] = 1
+ if "max_completion_tokens" in req_data:
+ req_data["max_completion_tokens"] = 1
+
+ headers = {
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+ "X-Request-Id": request_id
+ }
+ response = await client.post(endpoint, json=req_data, headers=headers)
+ response.raise_for_status()
+ return response
+
+
+async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
+ req_data: dict, request_id: str):
+ """
+ Asynchronously stream the response from a service using a persistent client.
+ """
+ headers = {
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+ "X-Request-Id": request_id
+ }
+ async with client.stream("POST", endpoint, json=req_data,
+ headers=headers) as response:
+ response.raise_for_status()
+ async for chunk in response.aiter_bytes():
+ yield chunk
+
+
+@app.post("/v1/completions")
+async def handle_completions(request: Request):
+ global counter, stats_calculator
+ counter += 1
+
+ st = time.time()
+ try:
+ req_data = await request.json()
+ request_id = str(uuid4())
+
+ # Send request to prefill service, ignore the response
+ await send_request_to_service(app.state.prefill_client, "/completions",
+ req_data, request_id)
+
+ et = time.time()
+ stats_calculator.add((et - st) * 1000)  # convert seconds to ms to match the logged units
+
+ # Stream response from decode service
+ async def generate_stream():
+ async for chunk in stream_service_response(app.state.decode_client,
+ "/completions",
+ req_data, request_id):
+ yield chunk
+
+ return StreamingResponse(generate_stream(),
+ media_type="text/event-stream")
+
+ except Exception as e:
+ import sys
+ import traceback
+
+ exc_info = sys.exc_info()
+ print(
+ "Error occurred in disagg prefill proxy server - completions endpoint"
+ )
+ print(e)
+ print("".join(traceback.format_exception(*exc_info)))
+ raise
+
+
+@app.post("/v1/chat/completions")
+async def handle_chat_completions(request: Request):
+ global counter, stats_calculator
+ counter += 1
+
+ st = time.time()
+ try:
+ req_data = await request.json()
+ request_id = str(uuid4())
+
+ # Send request to prefill service, ignore the response
+ await send_request_to_service(app.state.prefill_client,
+ "/chat/completions", req_data,
+ request_id)
+
+ et = time.time()
+ stats_calculator.add((et - st) * 1000)  # convert seconds to ms to match the logged units
+
+ # Stream response from decode service
+ async def generate_stream():
+ async for chunk in stream_service_response(app.state.decode_client,
+ "/chat/completions",
+ req_data, request_id):
+ yield chunk
+
+ return StreamingResponse(generate_stream(),
+ media_type="text/event-stream")
+
+ except Exception as e:
+ import sys
+ import traceback
+
+ exc_info = sys.exc_info()
+ print(
+ "Error occurred in disagg prefill proxy server - chat completions endpoint"
+ )
+ print(e)
+ print("".join(traceback.format_exception(*exc_info)))
+ raise
+
+
+if __name__ == "__main__":
+ global global_args
+ global_args = parse_args()
+
+ import uvicorn
+
+ uvicorn.run(app, host=global_args.host, port=global_args.port)
diff --git a/tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py b/tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py
new file mode 100644
index 0000000..aaf6da3
--- /dev/null
+++ b/tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py
@@ -0,0 +1,141 @@
+import json
+import os
+import signal
+import subprocess
+import time
+
+import pytest
+import requests
+
+HOST_IP = "127.0.0.1"
+MODEL_NAME = "Qwen/Qwen2.5-7B"
+WORKSPACE_DIR = "./tests/e2e/pd_disaggreate/yuanrong/"
+
+RUN_INSTANCES_SCRIPT = os.path.join(WORKSPACE_DIR, "run_pd_instances.sh")
+RUN_PROXY_SERVER_SCRIPT = os.path.join(WORKSPACE_DIR, "run_proxy_server.sh")
+RUN_YUANRONG_SCRIPT = os.path.join(WORKSPACE_DIR, "run_yuanrong.sh")
+CLEAN_YUANRONG_SCRIPT = os.path.join(WORKSPACE_DIR, "clean_yuanrong.sh")
+PROXY_SERVER_SCRIPT = os.path.join(WORKSPACE_DIR, "simple_pd_proxy_server.py")
+PROXY_PORT = 8000
+PREFILL_PORT = 8100
+DECODE_PORT = 8200
+WORKER_PORT = 31530
+ETCD_PORT = 2411
+
+PROMPT_ANSWER = {
+ "who is the president of the united states?": "?\nDonald Trump"
+}
+RUN_INSTANCE_KEYWORDS = "vllm serve"
+RUN_PROXY_SERVER_KEYWORDS = "simple_pd_proxy_server.py"
+
+
+def start_yuanrong():
+ proc = subprocess.Popen([
+ "bash", RUN_YUANRONG_SCRIPT, f"{HOST_IP}", f"{WORKER_PORT}",
+ f"{ETCD_PORT}"
+ ])
+ proc.wait()
+
+
+def clean_yuanrong():
+ proc = subprocess.Popen(
+ ["bash", CLEAN_YUANRONG_SCRIPT, f"{HOST_IP}", f"{WORKER_PORT}"])
+ proc.wait()
+
+
+def start_instances():
+ proc = subprocess.Popen([
+ "bash", RUN_INSTANCES_SCRIPT, f"{MODEL_NAME}", f"{HOST_IP}",
+ f"{PREFILL_PORT}", f"{DECODE_PORT}"
+ ])
+ proc.wait()
+
+
+def start_proxy_server():
+ proc = subprocess.Popen([
+ "bash", RUN_PROXY_SERVER_SCRIPT, PROXY_SERVER_SCRIPT, f"{HOST_IP}",
+ f"{PROXY_PORT}", f"{PREFILL_PORT}", f"{DECODE_PORT}"
+ ])
+ proc.wait()
+
+
+def clean_instances_and_proxy_server():
+ instance_pids = get_pids_by_keyword(RUN_INSTANCE_KEYWORDS)
+ proxy_pids = get_pids_by_keyword(RUN_PROXY_SERVER_KEYWORDS)
+ for pid in proxy_pids + instance_pids:
+ pid = int(pid)
+ try:
+ os.kill(pid, signal.SIGINT)
+ except ProcessLookupError:
+ print(f"No such process with PID {pid}")
+ except PermissionError:
+ print(f"Permission denied to send SIGINT to PID {pid}")
+ except Exception as e:
+ print(f"Error: {e}")
+ time.sleep(3)
+ pid = int(pid)
+ try:
+ os.kill(pid, signal.SIGKILL)
+ except ProcessLookupError:
+ print(f"No such process with PID {pid}")
+ except PermissionError:
+ print(f"Permission denied to send SIGKILL to PID {pid}")
+ except Exception as e:
+ print(f"Error: {e}")
+
+
+def send_post_request(url, data):
+ try:
+ response = requests.post(url, json=data, timeout=10)
+ response.raise_for_status()
+ return response.text
+ except requests.exceptions.RequestException as e:
+ return f"Request failed: {e}"
+
+
+def get_pids_by_keyword(keyword):
+ try:
+ # Run 'ps aux' to get all running processes
+ result = subprocess.run(['ps', 'aux'],
+ stdout=subprocess.PIPE,
+ text=True)
+ lines = result.stdout.strip().split('\n')
+
+ matching_pids = []
+
+ for line in lines[1:]: # Skip the header line
+ if keyword in line:
+ parts = line.split()
+ pid = parts[1] # PID is the second column
+ matching_pids.append(pid)
+
+ return matching_pids
+ except Exception as e:
+ return f"error occurred trying to get PIDs of processes containing keyword {keyword}, error: {e}"
+
+
+@pytest.fixture
+def setup_and_clean_cluster():
+ start_yuanrong()
+ start_instances()
+ start_proxy_server()
+ time.sleep(3)
+ yield
+ clean_instances_and_proxy_server()
+ clean_yuanrong()
+
+
+def test_yuanrong_pd_dist(setup_and_clean_cluster):
+ proxy_url = f"http://{HOST_IP}:{PROXY_PORT}/v1/completions"
+ for prompt, answer in PROMPT_ANSWER.items():
+ data = {
+ "model": MODEL_NAME,
+ "prompt": prompt,
+ "max_tokens": 50,
+ "temperature": 0
+ }
+ response_str = send_post_request(proxy_url, data)
+ response_json = json.loads(response_str)
+ output = response_json["choices"][0]["text"]
+ assert output == answer, f"wrong response: {output}, expected: {answer}"
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 6e031ea..de09f5d 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -24,6 +24,7 @@ import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
+from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
@@ -435,6 +436,34 @@ class AscendAttentionBackendImpl(AttentionImpl):
ori_output[:, :, :] = output[:num_tokens, :, :]
return output.view(num_tokens, self.hidden_size)
+def wait_for_kv_layer_from_connector(layer_name: str):
+ if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+ return
+
+ connector = get_kv_transfer_group()
+
+ forward_context: ForwardContext = get_forward_context()
+ attn_metadata = forward_context.attn_metadata
+ if attn_metadata is None:
+ return
+ connector.wait_for_layer_load(layer_name)
+
+
+def maybe_save_kv_layer_to_connector(
+ layer_name: str,
+ kv_cache_layer: List[torch.Tensor],
+):
+ if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+ return
+
+ connector = get_kv_transfer_group()
+
+ forward_context: ForwardContext = get_forward_context()
+ attn_metadata = forward_context.attn_metadata
+ if attn_metadata is None:
+ return
+ connector.save_kv_layer(layer_name, kv_cache_layer,
+ attn_metadata)
def unified_ascend_attention_with_output(
query: torch.Tensor,
@@ -443,6 +472,7 @@ def unified_ascend_attention_with_output(
output: torch.Tensor,
layer_name: str,
) -> None:
+ wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]
@@ -455,6 +485,7 @@ def unified_ascend_attention_with_output(
attn_metadata,
output,
trace_flag=False)
+ maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index 4a4131e..6f91566 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -59,9 +59,9 @@ class AscendSchedulerConfig(SchedulerConfig):
raise NotImplementedError(
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
)
- if self.is_multimodal_model:
- raise NotImplementedError(
- "currently AscendScheduler only supports LLM models.")
+ #if self.is_multimodal_model:
+ # raise NotImplementedError(
+ # "currently AscendScheduler only supports LLM models.")
if self.num_scheduler_steps > 1:
raise NotImplementedError(
"currently AscendScheduler doesn't support multi-step.")
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index dfdc9aa..ce8bd66 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -62,6 +62,9 @@ class AscendScheduler(Scheduler):
req_to_new_block_ids: dict[str, list[int]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
+ # Encoder-related.
+ scheduled_encoder_inputs: dict[str, list[int]] = {}
+ encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
@@ -129,6 +132,9 @@ class AscendScheduler(Scheduler):
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
+ encoder_inputs_to_schedule = None
+ new_encoder_budget = encoder_budget
+
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
@@ -165,6 +171,17 @@ class AscendScheduler(Scheduler):
continue
assert num_new_tokens > 0
blocks = new_computed_blocks.blocks[0]
+
+ # Schedule encoder inputs.
+ if request.has_encoder_inputs:
+ (encoder_inputs_to_schedule, num_new_tokens,
+ new_encoder_budget
+ ) = self._try_schedule_encoder_inputs(
+ request, num_computed_tokens, num_new_tokens,
+ encoder_budget)
+ if num_new_tokens == 0:
+ # The request cannot be scheduled.
+ break
watermark = getattr(self.scheduler_config, "watermark", 0.01)
if not self._check_watermark_for_prefill(request, num_new_tokens,
@@ -224,9 +241,17 @@ class AscendScheduler(Scheduler):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
- # Count the number of prefix cached tokens.
+ # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
+ # Encoder-related.
+ if encoder_inputs_to_schedule:
+ scheduled_encoder_inputs[request.request_id] = (
+ encoder_inputs_to_schedule)
+ # Allocate the encoder cache.
+ for i in encoder_inputs_to_schedule:
+ self.encoder_cache_manager.allocate(request, i)
+ encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
@@ -252,6 +277,16 @@ class AscendScheduler(Scheduler):
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
+
+ # Schedule encoder inputs.
+ encoder_inputs_to_schedule = None
+ new_encoder_budget = encoder_budget
+ if request.has_encoder_inputs:
+ (encoder_inputs_to_schedule, num_new_tokens,
+ new_encoder_budget) = self._try_schedule_encoder_inputs(
+ request, request.num_computed_tokens, num_new_tokens,
+ encoder_budget)
+
# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request and (
@@ -274,10 +309,15 @@ class AscendScheduler(Scheduler):
req_index += 1
continue
+ num_draft_tokens = max(
+ num_new_tokens + request.num_computed_tokens -
+ request.num_tokens, 0)
+
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
+ num_new_computed_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
# The request cannot be scheduled.
@@ -323,7 +363,16 @@ class AscendScheduler(Scheduler):
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
-
+
+ # Encoder-related.
+ if encoder_inputs_to_schedule:
+ scheduled_encoder_inputs[request.request_id] = (
+ encoder_inputs_to_schedule)
+ # Allocate the encoder cache.
+ for i in encoder_inputs_to_schedule:
+ self.encoder_cache_manager.allocate(request, i)
+ encoder_budget = new_encoder_budget
+
# Record scheduled LoRA requests.
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
@@ -364,7 +413,7 @@ class AscendScheduler(Scheduler):
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
- scheduled_encoder_inputs={},
+ scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
@@ -485,3 +534,40 @@ class AscendScheduler(Scheduler):
return super().update_from_output(scheduler_output,
model_runner_output)
+
+ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
+ """
+ KV Connector: check if the request_id is finished_recving.
+
+ The finished_recving_kv_req_ids list is populated
+ on the previous steps()'s update_from_output based
+ on the worker side connector.
+
+ When the kv transfer is ready, we cache the blocks
+ and the request state will be moved back to WAITING from
+ WAITING_FOR_REMOTE_KV.
+ """
+ assert self.connector is not None
+ if request.request_id not in self.finished_recving_kv_req_ids:
+ return False
+
+ # Now that the blocks are ready, actually cache them.
+ (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
+ num_computed_tokens = len(block_ids) * self.block_size
+ # Handle the case where the number of request tokens is less than one block.
+ num_computed_tokens = min(num_computed_tokens, request.num_tokens)
+ if num_computed_tokens == request.num_tokens:
+ num_computed_tokens -= 1
+
+ # This will cache the blocks if caching is enabled.
+ # Note: vLLM fixed this on the main branch, but the issue still exists on v0.9.1, so we adopt the
+ # change for v0.9.1 without cherry-picking it back to the vllm-ascend main branch.
+ if self.kv_cache_manager.enable_caching:
+ self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+
+ # Update the request state for scheduling.
+ request.num_computed_tokens = num_computed_tokens
+
+ # Return that we are ready.
+ self.finished_recving_kv_req_ids.remove(request.request_id)
+ return True
diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py
index ebe8694..ee9f216 100644
--- a/vllm_ascend/distributed/__init__.py
+++ b/vllm_ascend/distributed/__init__.py
@@ -22,3 +22,7 @@ KVConnectorFactory.register_connector(
"LLMDataDistCMgrConnector",
"vllm_ascend.distributed.llmdatadist_c_mgr_connector",
"LLMDataDistCMgrConnector")
+
+KVConnectorFactory.register_connector(
+ "YuanRongConnector", "vllm_ascend.distributed.yuanrong_connector",
+ "YuanRongConnector")
diff --git a/vllm_ascend/distributed/yuanrong_connector.py b/vllm_ascend/distributed/yuanrong_connector.py
new file mode 100644
index 0000000..9b7407a
--- /dev/null
+++ b/vllm_ascend/distributed/yuanrong_connector.py
@@ -0,0 +1,828 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import enum
+import hashlib
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional, Any
+import threading
+from collections import defaultdict
+import asyncio
+
+import numpy
+import torch
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.logger import init_logger
+from vllm.v1.attention.backends.mla.common import MLACommonMetadata
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.distributed.parallel_state import get_tp_group
+
+from datasystem import DsTensorClient, Future
+
+ENABLE_PREFIX_CACHING = int(os.environ.get("USING_PREFIX_CONNECTOR", 1))
+FUTURE_TIMEOUT = int(os.getenv("FUTURE_TIMEOUT", 10000))
+SYNC_FUTURE_TIMEOUT = int(os.getenv("SYNC_FUTURE_TIMEOUT", 1))
+SLEEP_TIMEOUT = 0.005
+
+if TYPE_CHECKING:
+ from vllm.attention.backends.abstract import AttentionMetadata
+ from vllm.forward_context import ForwardContext
+ from vllm.v1.request import Request
+
+logger = init_logger(f"vllm.{__name__}")
+
+
+class RequestStatus(enum.IntEnum):
+ WAITING = enum.auto()
+ TIMEOUT = enum.auto()
+ FINISHED = enum.auto()
+
+
+@dataclass
+class RequestTracker:
+ request_id: str
+ token_ids: torch.Tensor
+ block_ids: list[int]
+ num_scheduled_tokens: int
+
+ @staticmethod
+ def from_new_request(request_id, token_ids, block_ids, num_scheduled_tokens) -> "RequestTracker":
+ """
+ Create the request tracker from a new request.
+ """
+ return RequestTracker(
+ request_id=request_id,
+ token_ids=token_ids,
+ block_ids=block_ids,
+ num_scheduled_tokens=num_scheduled_tokens
+ )
+
+ def update(
+ self,
+ block_ids,
+ num_external_scheduled_tokens
+ ) -> None:
+ """
+ Update the request tracker when a running request is
+ scheduled again
+ """
+ self.block_ids[0].extend(block_ids[0])
+ self.num_scheduled_tokens += num_external_scheduled_tokens
+
+
+@dataclass
+class ReqMeta:
+ request_id: str
+ token_ids: torch.Tensor
+ block_ids: list[int]
+ request_rank: int
+ skip_block_num: int
+ ds_cached_block_num: int
+ need_save: bool
+
+ @staticmethod
+ def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
+ block_size: int, request_rank: int, skip_block_num: int, ds_cached_block_num: int, need_save: bool) \
+ -> "ReqMeta":
+ """make request meta"""
+ valid_num_tokens = align_to_block_size(len(token_ids), block_size)
+ valid_block_ids = valid_num_tokens // block_size
+ return ReqMeta(
+ request_id=request_id,
+ token_ids=numpy.array(token_ids),
+ block_ids=block_ids[0][:valid_block_ids],
+ request_rank=request_rank,
+ skip_block_num=skip_block_num,
+ ds_cached_block_num=ds_cached_block_num,
+ need_save=need_save
+ )
+
+
+@dataclass
+class YuanRongConnectorMetadata(KVConnectorMetadata):
+ requests: list[ReqMeta]
+
+ def __init__(self, tp_size, block_size):
+ self.requests = []
+ self.tp_size = tp_size
+ self.request_rank = 0
+ self._block_size = block_size
+
+ def add_request(
+ self,
+ request_id: str,
+ token_ids: list[int],
+ block_ids: list[int],
+ skip_block_num: int,
+ ds_cached_block_num: int,
+ need_save: bool = True
+ ) -> None:
+ """add request meta"""
+ request_rank = self.request_rank % self.tp_size
+ self.requests.append(
+ ReqMeta.make_meta(request_id, token_ids, block_ids, self._block_size, request_rank, skip_block_num,
+ ds_cached_block_num, need_save))
+ self.request_rank = request_rank + 1
+
+
+@dataclass
+class ReqState:
+ """Per-request state for tracking async transfers."""
+ num_pending: int = -1
+ finished: bool = False
+
+
+class AsyncHandler:
+ """Manage async saving/loading in separate thread."""
+
+ def __init__(self, role, task_list):
+ self._async_save_reqs = defaultdict[str, ReqState](ReqState)
+ self._async_load_reqs = defaultdict[str, ReqState](ReqState)
+ self._is_producer = role
+ self._finished_save_reqs = asyncio.Queue()
+ self._finished_load_reqs = asyncio.Queue()
+ self._future_save_list = asyncio.Queue()
+ self._future_load_list = asyncio.Queue()
+ if self._is_producer or ENABLE_PREFIX_CACHING:
+ task_list.append(asyncio.get_event_loop().create_task(self.get_save_futures_async()))
+ if not self._is_producer or ENABLE_PREFIX_CACHING:
+ task_list.append(asyncio.get_event_loop().create_task(self.get_load_futures_async()))
+
+ async def get_save_futures_async(self):
+ """async get save futures"""
+ while True:
+ try:
+ save_future_len = self._future_save_list.qsize()
+ for _ in range(save_future_len):
+ request_id, future = self._future_save_list.get_nowait()
+ res = get_future(future)
+ req_state = self._async_save_reqs[request_id]
+ if res == RequestStatus.FINISHED:
+ logger.info(f"request: {request_id} is finished")
+ req_state.num_pending -= 1
+ if req_state.finished and not req_state.num_pending:
+ self._finished_save_reqs.put_nowait(request_id)
+ del self._async_save_reqs[request_id]
+ elif res == RequestStatus.WAITING or not req_state.finished:
+ self._future_save_list.put_nowait((request_id, future))
+ else:
+ logger.error(f"request:{request_id} get save future timeout, res:{res}")
+ self._finished_save_reqs.put_nowait(request_id)
+ del self._async_save_reqs[request_id]
+ await asyncio.sleep(SLEEP_TIMEOUT)
+ except Exception as e:
+ logger.error(f"get_futures_async fail, error:{e}")
+
+ async def get_load_futures_async(self):
+ """async get load futures"""
+ while True:
+ try:
+ load_future_len = self._future_load_list.qsize()
+ for _ in range(load_future_len):
+ request_id, future = self._future_load_list.get_nowait()
+ res = get_future(future)
+ req_state = self._async_load_reqs[request_id]
+ if res == RequestStatus.FINISHED:
+ logger.info(f"request: {request_id} is finished")
+ req_state.num_pending -= 1
+ if not req_state.num_pending:
+ self._finished_load_reqs.put_nowait(request_id)
+ del self._async_load_reqs[request_id]
+ elif res == RequestStatus.WAITING:
+ self._future_load_list.put_nowait((request_id, future))
+ else:
+ logger.error(f"request:{request_id} get load future timeout, res:{res}")
+ self._finished_load_reqs.put_nowait(request_id)
+ del self._async_load_reqs[request_id]
+ await asyncio.sleep(SLEEP_TIMEOUT)
+ except Exception as e:
+ logger.error(f"get_futures_async fail, error:{e}")
+
+ def add_save_request(self, request: ReqMeta, future_num: int) -> None:
+ """add save request"""
+ self._async_save_reqs[request.request_id].num_pending = future_num
+
+ def add_load_request(self, request: ReqMeta, future_num: int) -> None:
+ """add load request"""
+ self._async_load_reqs[request.request_id].num_pending = future_num
+
+ def add_save_future(self, request: ReqMeta, future: Future) -> None:
+ """add save request future"""
+ self._future_save_list.put_nowait((request.request_id, future))
+
+ def add_load_future(self, request: ReqMeta, future: Future) -> None:
+ """add load request future"""
+ self._future_load_list.put_nowait((request.request_id, future))
+
+ def get_save_finished(self, finished_request_ids: set[str]) -> Optional[set[str]]:
+ """Finished saving request ids."""
+ finished_reqs = set()
+ for req_id in finished_request_ids:
+ req_state = self._async_save_reqs[req_id]
+ if req_state:
+ req_state.finished = True
+ if not req_state.num_pending:
+ finished_reqs.add(req_id)
+ del self._async_save_reqs[req_id]
+
+ while not self._finished_save_reqs.empty():
+ finished_reqs.add(self._finished_save_reqs.get_nowait())
+ if len(finished_reqs) != 0:
+ logger.debug(f"get_finished, finished_reqs:{finished_reqs}, length:{len(finished_reqs)}")
+ else:
+ finished_reqs = None
+ return finished_reqs
+
+ def get_load_finished(self) -> set[str]:
+ """Finished loading request ids."""
+ finished_reqs = set()
+ while not self._finished_load_reqs.empty():
+ finished_reqs.add(self._finished_load_reqs.get_nowait())
+ if len(finished_reqs) != 0:
+ logger.debug(f"get_finished, finished_reqs:{finished_reqs}, length:{len(finished_reqs)}")
+ else:
+ finished_reqs = None
+ return finished_reqs
+
+
+class YuanRongConnector(KVConnectorBase_V1):
+
+ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+ super().__init__(vllm_config=vllm_config, role=role)
+ self._block_size = vllm_config.cache_config.block_size
+ self._requests_need_load: dict[str, Request] = {}
+ self.config = vllm_config.kv_transfer_config
+ self.is_producer = self.config.is_kv_producer
+ self.do_async_save = int(os.getenv("ASYNC_SAVE", 1))
+ self.layer_name_list = []
+ self.kv_caches = []
+ self.key_caches = []
+ self.value_caches = []
+ self._skip_blocks: dict[str, int] = {}
+ self._ds_cached_blocks: dict[str, int] = {}
+ self._delay_save = {}
+ self._load_request_queue = asyncio.Queue()
+ self._save_request_queue = asyncio.Queue()
+ self.task_list = []
+ self.is_ms_non_mla_type = False
+ self.is_ms_mla = False
+ self.is_mla = False
+ self._async_handler = None
+
+ thread_num = int(os.getenv("THREAD_NUM", 32))
+ conn_timeout_ms = int(os.getenv("CONN_TIMEOUT_MS", 6000))
+ self.tp_size = vllm_config.parallel_config.tensor_parallel_size
+ ds_worker_addr = os.getenv("DS_WORKER_ADDR", "172.17.0.4:9000")
+ ip_port = ds_worker_addr.split(":")
+ ip = ip_port[0]
+ port = int(ip_port[1])
+
+ self.device = self.tp_rank = 0
+ if role == KVConnectorRole.WORKER:
+ self.tp_rank = get_tp_group().rank_in_group
+ self.tp_group = get_tp_group()
+ self.kvc_store = DsTensorClient(ip, port, self.device)
+ self.kvc_store.init()
+ if self.do_async_save:
+ self.loop = asyncio.get_event_loop()
+ self._async_handler = AsyncHandler(self.is_producer, self.task_list)
+ if ENABLE_PREFIX_CACHING or not self.is_producer:
+ self.task_list.append(self.loop.create_task(self.consumer_request_task()))
+
+ if ENABLE_PREFIX_CACHING or self.is_producer:
+ self.task_list.append(self.loop.create_task(self.producer_request_task()))
+
+ thread = threading.Thread(target=self.start_event_loop, daemon=True)
+ thread.start()
+ elif ENABLE_PREFIX_CACHING:
+ thread_num = 1
+ self.kvc_store = DsTensorClient(ip, port, self.device)
+ self.kvc_store.init()
+ else:
+ self.tp_group = None
+ logger.info(f"init datasystem ip = {ip}, port = {port}, device_id = {self.device}")
+
+ def start_event_loop(self):
+ """start event loop"""
+ current_thread = threading.current_thread()
+ logger.info(f"start_event_loop: {current_thread.ident}")
+ self.loop.run_until_complete(asyncio.gather(*self.task_list))
+ self.loop.close()
+
+ async def producer_request_task(self):
+ """producer request task"""
+ while True:
+ try:
+ save_request_len = self._save_request_queue.qsize()
+ for _ in range(save_request_len):
+ request = self._save_request_queue.get_nowait()
+ self.do_save_request(request)
+ await asyncio.sleep(SLEEP_TIMEOUT)
+ except Exception as e:
+ logger.error(f"producer_request_task fail, error:{e}")
+ self._save_request_queue.put_nowait(request)
+ await asyncio.sleep(SLEEP_TIMEOUT)
+
+ async def consumer_request_task(self):
+ """consumer request task"""
+ while True:
+ try:
+ load_request_len = self._load_request_queue.qsize()
+ for _ in range(load_request_len):
+ request = self._load_request_queue.get_nowait()
+ self.do_load_kv(request)
+ await asyncio.sleep(SLEEP_TIMEOUT)
+ except Exception as e:
+ logger.error(f"consumer_request_task fail, error:{e}")
+ self._load_request_queue.put_nowait(request)
+ await asyncio.sleep(SLEEP_TIMEOUT)
+
+ def generate_kv_cache_token_key(
+ self,
+ request: ReqMeta,
+ block_start_index: int,
+ block_end_index: int
+ ) -> List[str]:
+ """
+ generate kv_cache token key.
+ """
+ if not self.is_mla:
+ external_key = "-" + str(self.tp_rank)
+ else:
+ external_key = "-0"
+
+ return generate_hash_sha256(block_start_index, block_end_index, request.token_ids,
+ self._block_size, external_key)
+
+ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
+ """
+ Start loading the KV cache from the connector buffer to vLLM's paged KV buffer.
+
+ Args:
+ forward_context (ForwardContext): the forward context.
+ **kwargs: additional arguments for the load operation
+
+ Note:
+ The number of elements in kv_caches and layer_names should be
+ the same.
+ """
+ # effective only when prefix cache is disabled and the role is producer.
+ if self.is_producer and not ENABLE_PREFIX_CACHING:
+ return
+
+ metadata: KVConnectorMetadata = self._get_connector_metadata()
+ if len(metadata.requests) == 0:
+ return
+
+ if len(self.kv_caches) == 0:
+ self._init_kv_caches_from_forward_context(forward_context)
+
+ for request in metadata.requests:
+ if self._async_handler is not None:
+ self._load_request_queue.put_nowait(request)
+ else:
+ self.do_load_kv(request)
+
+ def get_finished(
+ self, finished_req_ids: set[str]
+ ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+ """Finished (saving, loading) request ids."""
+ finished_saved_req, finished_loaded_req = None, None
+ if self._async_handler is not None:
+ if self.is_producer or ENABLE_PREFIX_CACHING:
+ finished_saved_req = self._async_handler.get_save_finished(finished_req_ids)
+
+ if not self.is_producer or ENABLE_PREFIX_CACHING:
+ finished_loaded_req = self._async_handler.get_load_finished()
+
+ return finished_saved_req, finished_loaded_req
+ return None, None
+
+ def get_sending_count(self):
+ """
+ Return count of finished sending requests aggregated.
+ For mla model, just save kvc for tp rank = 0
+ """
+ if self.is_mla:
+ return 1
+ return self.tp_size
+
+ def do_load_kv(self, request) -> None:
+ """
+ Start loading the KV cache from the connector buffer to vLLM's paged KV buffer.
+
+ Note:
+ The number of elements in kv_caches and layer_names should be the same.
+ """
+ ds_cached_block_num = request.ds_cached_block_num
+ skip_block_num = request.skip_block_num
+ logger.debug(f"request:{request.request_id}, ds_cached_block_num: {ds_cached_block_num}, "
+ f"skip_block_num: {skip_block_num}")
+ if ds_cached_block_num == 0:
+ return
+ key_list = self.generate_kv_cache_token_key(request, skip_block_num, ds_cached_block_num)
+ block_id_list = request.block_ids
+ if not block_id_list or not key_list:
+ return
+ if not self.is_mla:
+ value_cache_key_list = [key + "-value" for key in key_list]
+ if len(key_list) != len(block_id_list):
+ logger.error(f"mget_tensors_h2d fail, request.request_id:{request.request_id}.")
+
+ future = self.kvc_store.mget_page_attn_blockwise_h2d(key_list, self.key_caches, block_id_list)
+ future_1 = self.kvc_store.mget_page_attn_blockwise_h2d(value_cache_key_list, self.value_caches,
+ block_id_list)
+ if not self.do_async_save:
+ get_future(future, SYNC_FUTURE_TIMEOUT)
+ get_future(future_1, SYNC_FUTURE_TIMEOUT)
+ else:
+ self._async_handler.add_load_request(request, 2)
+ self._async_handler.add_load_future(request, future)
+ self._async_handler.add_load_future(request, future_1)
+ logger.debug(f"mget_tensors_h2d success, request.request_id:{request.request_id}, "
+ f"key_list length:{len(key_list)}")
+ return
+
+ future = self.kvc_store.mget_page_attn_blockwise_h2d(key_list, self.kv_caches, block_id_list)
+ if not self.do_async_save:
+ get_future(future, SYNC_FUTURE_TIMEOUT)
+ else:
+ self._async_handler.add_load_request(request, 1)
+ self._async_handler.add_load_future(request, future)
+ logger.debug(f"mget_tensors_h2d success, request.request_id:{request.request_id}, "
+ f"key_list length:{len(key_list)}")
+
+ def wait_for_layer_load(self, layer_name: str) -> None:
+ """
+ wait_for_layer_load
+ """
+ return
+
+ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
+ attn_metadata: "AttentionMetadata", **kwargs) -> None:
+ """
+ save_kv_layer
+ """
+ if not ENABLE_PREFIX_CACHING and not self.is_producer:
+ return
+
+ if layer_name not in self.layer_name_list:
+ self.layer_name_list.append(layer_name)
+ self.is_ms_non_mla_type = isinstance(kv_layer, tuple) and len(kv_layer) == 2
+ self.is_ms_mla = os.getenv("vLLM_MODEL_BACKEND", None) == "MindFormers" and not self.is_ms_non_mla_type
+ self.is_mla = isinstance(attn_metadata, MLACommonMetadata) or self.is_ms_mla
+ if self.is_mla:
+ self.kv_caches.append(kv_layer)
+ else:
+ self.key_caches.append(kv_layer[0])
+ self.value_caches.append(kv_layer[1])
+
+ def do_save_request(self, request) -> None:
+ """
+ Start saving the KV cache of the layer from vLLM's paged buffer to the connector.
+ """
+ logger.debug(f"do_save_request, request:{request}")
+ if not self.is_producer or not request.need_save:
+ return
+
+ if self.is_mla and self.tp_rank != request.request_rank:
+ return
+
+ if not request.block_ids:
+ return
+
+ token_key_list = self.generate_kv_cache_token_key(request, 0, len(request.block_ids))
+ if not self.is_mla:
+ value_cache_key_list = [key + "-value" for key in token_key_list]
+ future = self.kvc_store.mset_page_attn_blockwise_d2h(token_key_list, self.key_caches, request.block_ids)
+ future_1 = self.kvc_store.mset_page_attn_blockwise_d2h(value_cache_key_list, self.value_caches,
+ request.block_ids)
+ if not self.do_async_save:
+ get_future(future, SYNC_FUTURE_TIMEOUT)
+ get_future(future_1, SYNC_FUTURE_TIMEOUT)
+ else:
+ self._async_handler.add_save_request(request, 2)
+ self._async_handler.add_save_future(request, future)
+ self._async_handler.add_save_future(request, future_1)
+ logger.debug(f"mset_tensors_d2h success, request.request_id:{request.request_id}, "
+ f"key_list length:{len(token_key_list)}")
+ return
+
+ future = self.kvc_store.mset_page_attn_blockwise_d2h(token_key_list, self.kv_caches, request.block_ids)
+ if not self.do_async_save:
+ get_future(future, SYNC_FUTURE_TIMEOUT)
+ else:
+ self._async_handler.add_save_request(request, 1)
+ self._async_handler.add_save_future(request, future)
+ logger.debug(f"mset_tensors_d2h success, request.request_id:{request.request_id}, "
+ f"key_list length:{len(token_key_list)}.")
+
+ def wait_for_save(self) -> None:
+ """
+ wait_for_save
+ """
+ if not self.is_producer:
+ return
+ connector_metadata = self._get_connector_metadata()
+ if not isinstance(connector_metadata, YuanRongConnectorMetadata):
+ raise ValueError("connector_metadata is not an instance of YuanRongConnectorMetadata")
+
+ if not connector_metadata.requests:
+ return
+
+ for request in connector_metadata.requests:
+ if self._async_handler is not None:
+ self._save_request_queue.put_nowait(request)
+ else:
+ self.do_save_request(request)
+
+ def get_num_new_matched_tokens(
+ self,
+ request: "Request",
+ num_computed_tokens: int,
+ ) -> tuple[int, bool]:
+ """
+ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens.
+
+ Args:
+ request (Request): the request object.
+ num_computed_tokens (int): the number of locally
+ computed tokens for this request
+
+ Returns:
+ the number of tokens that can be loaded from the
+ external KV cache beyond what is already computed.
+ """
+ num_computed_blocks = num_computed_tokens // self._block_size
+ num_tokens_to_check = align_to_block_size(len(request.prompt_token_ids), self._block_size)
+ prompt_blocks = num_tokens_to_check // self._block_size
+ num_external_hit_tokens = 0
+ if not self.is_producer:
+ self._skip_blocks[request.request_id] = num_computed_blocks
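+ # All prompt tokens except the last are reported as loadable from the external KV cache, so at least one token is still computed locally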
+ num_external_computed_tokens = len(request.prompt_token_ids) - num_computed_tokens - 1
+ self._ds_cached_blocks[request.request_id] = prompt_blocks
+ if self.do_async_save and num_external_computed_tokens > 0:
+ logger.info(f"request_id:{request.request_id}, num_computed_tokens:{num_computed_tokens}, "
+ f"num_external_computed_tokens:{num_external_computed_tokens}")
+ return num_external_computed_tokens, True
+
+ return num_external_computed_tokens, False
+ if ENABLE_PREFIX_CACHING:
+ tokens = request.prompt_token_ids
+ keys = generate_hash_sha256(num_computed_blocks, prompt_blocks, numpy.array(tokens), self._block_size, "-0")
+ if not keys:
+ logger.info(
+ "Reqid: %s, Total tokens %d, HBM hit tokens: %d, "
+ "need to load: 0", request.request_id, request.num_tokens, num_computed_tokens)
+ return 0, False
+
+ try:
+ exists = self.kvc_store.exist(keys) + [False]
+ except RuntimeError:
+ logger.info(
+ "Reqid: %s, Total tokens %d, HBM hit tokens: %d, "
+ "need to load: 0", request.request_id, request.num_tokens, num_computed_tokens)
+ return 0, False
+
+ num_external_hit_blocks = exists.index(False)
+ num_external_hit_tokens = num_external_hit_blocks * self._block_size
+
+ self._skip_blocks[request.request_id] = num_computed_blocks
+ self._ds_cached_blocks[request.request_id] = num_external_hit_blocks + num_computed_blocks
+
+ logger.info(
+ "Reqid: %s, Total tokens %d, HBM hit tokens: %d, "
+ "need to load: %d", request.request_id, request.num_tokens, num_computed_tokens,
+ num_external_hit_tokens)
+
+ if self.do_async_save and num_external_hit_tokens > 0:
+ return num_external_hit_tokens, True
+
+ return num_external_hit_tokens, False
+
+ def update_state_after_alloc(
+ self,
+ request: "Request",
+ blocks: "KVCacheBlocks",
+ num_external_tokens: int
+ ) -> None:
+ """
+ Update KVConnector state after block allocation.
+
+ If blocks were allocated, add to _requests_need_load,
+ such that we load the KVs in the next forward pass.
+ """
+ if num_external_tokens > 0:
+ block = blocks.get_unhashed_block_ids()
+ self._requests_need_load[request.request_id] = (request, [block])
+ logger.debug(f"_requests_need_load add request_id: {request.request_id}")
+
+ def build_connector_meta(
+ self,
+ scheduler_output: SchedulerOutput,
+ ) -> KVConnectorMetadata:
+ """
+ Build the connector metadata for this step.
+
+ This function should NOT modify any fields in the scheduler_output.
+ Also, calling this function will reset the state of the connector.
+
+ Args:
+ scheduler_output (SchedulerOutput): the scheduler output object.
+ """
+ meta = YuanRongConnectorMetadata(self.tp_size, self._block_size)
+ total_need_load = 0
+ for new_req in scheduler_output.scheduled_new_reqs:
+ if new_req.req_id in self._requests_need_load:
+ meta.add_request(request_id=new_req.req_id,
+ token_ids=new_req.prompt_token_ids,
+ block_ids=new_req.block_ids,
+ skip_block_num=self._skip_blocks.pop(new_req.req_id, 0),
+ ds_cached_block_num=self._ds_cached_blocks.pop(new_req.req_id, 0))
+ total_need_load += 1
+ else:
+ if self.is_producer:
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(new_req.req_id)
+ num_scheduled_tokens += new_req.num_computed_tokens
+ if len(new_req.prompt_token_ids) > num_scheduled_tokens:
+ self._delay_save[new_req.req_id] = RequestTracker.from_new_request(new_req.req_id,
+ new_req.prompt_token_ids,
+ new_req.block_ids,
+ num_scheduled_tokens)
+ else:
+ meta.add_request(request_id=new_req.req_id,
+ token_ids=new_req.prompt_token_ids,
+ block_ids=new_req.block_ids,
+ skip_block_num=self._skip_blocks.pop(new_req.req_id, 0),
+ ds_cached_block_num=self._ds_cached_blocks.pop(new_req.req_id, 0))
+
+ cached_reqs = scheduler_output.scheduled_cached_reqs
+ for i, req_id in enumerate(cached_reqs.req_ids):
+ new_block_ids = cached_reqs.new_block_ids[i]
+ resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+
+ # NOTE(rob): here we rely on the resumed requests being
+ # the first N requests in the list scheduled_cached_reqs.
+ if not resumed_from_preemption:
+ if req_id in self._delay_save:
+ request_tracker = self._delay_save.get(req_id)
+ num_external_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id)
+ request_tracker.update(new_block_ids, num_external_scheduled_tokens)
+ if len(request_tracker.token_ids) <= request_tracker.num_scheduled_tokens:
+ del self._delay_save[req_id]
+ logger.debug(f"add delay save request, request id:{request_tracker.request_id}")
+ meta.add_request(request_id=request_tracker.request_id,
+ token_ids=request_tracker.token_ids,
+ block_ids=request_tracker.block_ids,
+ skip_block_num=self._skip_blocks.pop(request_tracker.request_id, 0),
+ ds_cached_block_num=self._ds_cached_blocks.pop(request_tracker.request_id, 0))
+
+ if req_id in self._requests_need_load:
+ # NOTE(rob): cached_req_data does not have the full
+ # list of token ids (only new tokens). So we look it
+ # up in the actual request object.
+ request, _ = self._requests_need_load[req_id]
+ token_ids = request.all_token_ids[:len(request.prompt_token_ids)]
+ logger.debug(f"request_id:{request.request_id} resumed from preemption")
+ # NOTE(rob): For resumed req, new_block_ids is all of the block_ids for the request.
+ block_ids = new_block_ids
+ meta.add_request(request_id=req_id,
+ token_ids=token_ids,
+ block_ids=block_ids,
+ skip_block_num=self._skip_blocks.pop(req_id, 0),
+ ds_cached_block_num=self._ds_cached_blocks.pop(req_id, 0))
+ total_need_load += 1
+ if self.do_async_save:
+ for req_id, (req, block_ids) in self._requests_need_load.items():
+ if not block_ids:
+ logger.debug(
+ "Skipping adding request %s to ConnectorMetadata, "
+ "as there are no remote blocks to pull", req_id)
+ continue
+
+ meta.add_request(
+ request_id=req_id,
+ token_ids=req.prompt_token_ids,
+ block_ids=block_ids,
+ skip_block_num=self._skip_blocks.pop(req_id, 0),
+ ds_cached_block_num=self._ds_cached_blocks.pop(req_id, 0),
+ need_save=False)
+ total_need_load += 1
+
+ logger.debug(f"total_need_load:{total_need_load}, self._requests_need_load:{len(self._requests_need_load)}")
+ # Clear the list once workers start the transfers
+ if total_need_load != len(self._requests_need_load):
+ logger.error(f"total_need_load={total_need_load} "
+ f"is not equal to requests_need_load={len(self._requests_need_load)}")
+ raise ValueError("total_need_load is not equal to requests_need_load")
+ self._requests_need_load.clear()
+ return meta
+
+ def request_finished(
+ self,
+ request: "Request",
+ block_ids: list[int],
+ ) -> tuple[bool, Optional[dict[str, Any]]]:
+ """
+ Called when a request finishes: returns whether KV cache saving may still be running
+ asynchronously (producers with async save enabled) and an optional dict of transfer
+ parameters (always None here).
+ """
+ # Return True to indicate that saving may be happening asynchronously.
+ if self.is_producer:
+ return self.do_async_save, None
+
+ return False, None
+
+ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
+ """
+ Initialize KV caches from forward_context.
+
+ Args:
+ forward_context: the current ForwardContext holding the attention metadata and per-layer KV caches.
+ """
+ attn_metadata = forward_context.attn_metadata
+ for layer_name in forward_context.no_compile_layers:
+ attn_layer = forward_context.no_compile_layers[layer_name]
+ kv_layer = attn_layer.kv_cache[forward_context.virtual_engine]
+ self.is_ms_non_mla_type: bool = isinstance(kv_layer, tuple) and len(kv_layer) == 2
+ self.is_ms_mla = os.getenv("vLLM_MODEL_BACKEND", None) == "MindFormers" and not self.is_ms_non_mla_type
+ self.is_mla = isinstance(attn_metadata, MLACommonMetadata) or self.is_ms_mla
+ if layer_name not in self.layer_name_list:
+ self.layer_name_list.append(layer_name)
+ logger.debug(f"_init_kv_caches_from_forward_context, layer_name:{layer_name}")
+ if not self.is_mla:
+ self.key_caches.append(kv_layer[0])
+ self.value_caches.append(kv_layer[1])
+ elif self.is_ms_mla:
+ self.kv_caches.append(kv_layer[0])
+ else:
+ self.kv_caches.append(kv_layer)
+
+
+def extract_number(s: str) -> Optional[int]:
+ """extract number"""
+ parts = s.split('.')
+ for part in parts:
+ if part.isdigit():
+ return int(part)
+ return None
+
+
+def align_to_block_size(num_tokens: int, block_size: int) -> int:
+ """
+ Round (num_tokens - 1) up to the nearest multiple of block_size.
+ """
+ return (num_tokens + block_size - 2) // block_size * block_size
+
+
+def generate_hash_sha256(
+ block_start_index: int,
+ block_end_index: int,
+ token_ids: numpy.ndarray,
+ block_size: int,
+ external_key: str
+) -> List[str]:
+ """
+ Generate one SHA-256 prefix-hash key per KV cache block.
+
+ Args:
+ block_start_index: index of the first block to hash.
+ block_end_index: index one past the last block to hash.
+ token_ids: prompt token ids as a numpy array.
+ block_size: block size of vLLM, in tokens.
+ external_key: suffix appended to every generated key.
+ """
+ hash_list = []
+ for block_index in range(block_start_index, block_end_index):
+ end_index = (block_index + 1) * block_size
+ input_ids = token_ids[:end_index]
+ input_ids_bytes = input_ids.tobytes()
+ token_hash = hashlib.sha256(input_ids_bytes).hexdigest()
+ hash_list.append(token_hash + external_key)
+ return hash_list
+
+
+def get_future(fut: Future, timeout: int = FUTURE_TIMEOUT) -> RequestStatus:
+ """get future"""
+ try:
+ failed_list = fut.get(timeout)
+ except TimeoutError:
+ return RequestStatus.WAITING
+
+ if len(failed_list) != 0:
+ logger.error(f"failed_list: {failed_list}")
+ return RequestStatus.TIMEOUT
+
+ return RequestStatus.FINISHED
--
2.33.0