From e1ace83f2091f676ccbc7a1886595e24acd6056f Mon Sep 17 00:00:00 2001 From: zlq2020 Date: Fri, 10 Oct 2025 16:18:01 +0800 Subject: [PATCH] V1 support no cp schedule --- tests/st/python/test_scheduler.py | 286 +++++++ vllm_mindspore/__init__.py | 10 +- vllm_mindspore/engine/arg_utils.py | 13 +- vllm_mindspore/v1/core/sched/scheduler.py | 946 +++++++++++++++++----- vllm_mindspore/v1/metrics/__init__.py | 0 vllm_mindspore/v1/metrics/stats.py | 58 ++ 6 files changed, 1102 insertions(+), 211 deletions(-) create mode 100644 tests/st/python/test_scheduler.py create mode 100644 vllm_mindspore/v1/metrics/__init__.py create mode 100644 vllm_mindspore/v1/metrics/stats.py diff --git a/tests/st/python/test_scheduler.py b/tests/st/python/test_scheduler.py new file mode 100644 index 00000000..ce66872b --- /dev/null +++ b/tests/st/python/test_scheduler.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/tests/v1/core/ +# test_scheduler.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024-2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test v1 Scheduler.""" +import os +from typing import Optional +import pytest +from unittest.mock import patch + +import vllm_mindspore +import mindspore as ms + +from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, + SchedulerConfig, SpeculativeConfig, VllmConfig) +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheTensor) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus +from vllm.v1.structured_output import StructuredOutputManager +from vllm import LLM, SamplingParams + +from vllm_mindspore.v1.core.sched.scheduler import EnhancedScheduler + +from tests.st.python.utils.cases_parallel import cleanup_subprocesses + +EOS_TOKEN_ID = 151645 +MODEL_PATH = "/home/workspace/mindspore_dataset/weight/Qwen3-0.6B" + + +def teardown_function(): + cleanup_subprocesses() + + +def create_scheduler( + model: str = MODEL_PATH, + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + enable_chunked_prefill: bool = True, + enable_prefix_caching: bool = True, + long_prefill_token_threshold: int = 0, + disable_chunked_mm_input: bool = False, + num_blocks: int = 10000, + block_size: int = 16, + max_model_len: Optional[int] = None, +): + if max_model_len is None: + max_model_len = max_num_batched_tokens + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + long_prefill_token_threshold=long_prefill_token_threshold, + disable_chunked_mm_input=disable_chunked_mm_input, + enable_chunked_prefill=enable_chunked_prefill, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + 
trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + kwargs_cache = ({} if enable_prefix_caching is None else { + 'enable_prefix_caching': enable_prefix_caching + }) + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + **kwargs_cache, + ) + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, ms.float32, + False)) + ], + ) + cache_config.num_gpu_blocks = num_blocks + return EnhancedScheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_requests(num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[list[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None): + sampling_params = SamplingParams(ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs) + requests = [] + for i in range(num_requests): + if mm_positions is not None: + mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + mm_inputs = None + request = Request( + request_id=f"{i}", + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + ) + requests.append(request) + return requests + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_default_schedule(): + scheduler = create_scheduler(max_num_batched_tokens=1024, + enable_prefix_caching=True, + enable_chunked_prefill=True) + # 15 * 1000 > 1024 + requests = create_requests(num_requests=15, num_tokens=100) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 11 + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + + # Verify part requests are scheduled. 
+ for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert (num_tokens == len(requests[int(req_id)].prompt_token_ids) + or num_tokens == 24) + + # Verify requests moved from waiting to running + assert len(scheduler.waiting) == 4 + assert len(scheduler.running) == 11 + for i, request in enumerate(requests): + if i < len(scheduler.running): + assert scheduler.running[i] == request + else: + assert request in scheduler.waiting + + # second scheduling + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 4 + assert len(output.scheduled_cached_reqs) == 1 + assert len(output.finished_req_ids) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 15 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_enhanced_schedule(): + scheduler = create_scheduler(max_num_batched_tokens=1024, + enable_prefix_caching=False, + enable_chunked_prefill=False) + # 15 * 1000 > 1024 + requests = create_requests(num_requests=15, num_tokens=100) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 10 + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + + # Verify part requests are scheduled. + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + + # Verify requests moved from waiting to running + assert len(scheduler.waiting) == 5 + assert len(scheduler.running) == 10 + for i, request in enumerate(requests): + if i < len(scheduler.running): + assert scheduler.running[i] == request + else: + assert request in scheduler.waiting + + # second scheduling + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 5 + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 15 + + +env_vars = { + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "VLLM_MS_MODEL_BACKEND": "Native", + "HCCL_OP_EXPANSION_MODE": "AIV", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0", + "VLLM_USE_V1": "1", +} + + +@patch.dict(os.environ, env_vars) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_enhanced_schedule_end_to_end(): + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=4, top_k=1) + + # Create an LLM. + llm = LLM(model=MODEL_PATH, + gpu_memory_utilization=0.9, + tensor_parallel_size=1, + max_num_seqs=10, + max_num_batched_tokens=128, + max_model_len=256, + enable_prefix_caching=False, + enable_chunked_prefill=False) + + # long request will be ignored. + prompts1 = [ + "<|im_start|>user\n请你按照指定的情感分类标准,将提供的文本准确归类到中性、" + "负面或正面三个类别中的某一类。在进行分类前,可先结合文本所表达的态度、语气及" + "情感倾向进行简要分析:若文本仅客观陈述事实、未明显流露积极或消极倾向,则归为中性;" + "若文本包含肯定、喜爱、满意等积极情绪及正向评价,归为正面;若文本带有否定、不满、厌恶" + "等消极情绪及负面评价,归为负面。本次需要分类的文本内容为:我认为这次假期还可以。" + "请先简要说明分类依据(如文本中关键表述所传递的情感强度、是否存在明确的正负向倾向等)," + "再给出最终的情感分类结果。\n情感:<|im_end|>\n" + ] + outputs = llm.generate(prompts1, sampling_params) + finish_reason = outputs[0].outputs[0].finish_reason + assert finish_reason == "length" + + # norm requests will be scheduled. 
+ prompts2 = [ + "<|im_start|>user\n将文本分类为中性、负面或正面。 " + "\n文本:我认为这次假期还可以。 \n情感:<|im_end|>\n" + "<|im_start|>assistant\n\n\n\n\n", + ] * 5 + outputs = llm.generate(prompts2, sampling_params) + except_list = ['情感:中性'] + for _, output in enumerate(outputs): + generated_text = output.outputs[0].text + assert generated_text == except_list[ + 0], f"Expected: {except_list[0]}, but got: {generated_text}" diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 58c78e07..759a64f1 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -521,11 +521,6 @@ V1Worker.__init__ = (wrapper_worker_bind_cpu( V1Worker.init_device = wrapper_worker_init_device(V1Worker.init_device) V1Worker.compile_or_warm_up_model = compile_or_warm_up_model -from vllm_mindspore.v1.core.sched.scheduler import update_from_output -from vllm.v1.core.sched.scheduler import Scheduler - -Scheduler.update_from_output = update_from_output - from vllm_mindspore.v1.executor.multiproc_executor import ( executor_ensure_worker_termination, ) from vllm.v1.executor.multiproc_executor import MultiprocExecutor @@ -578,4 +573,9 @@ from vllm.v1.engine.processor import Processor Processor._validate_sampling_params = v1_process_validate_sampling_params Processor._validate_structured_output = v1_process_validate_structured_output +from vllm.v1.metrics.stats import IterationStats +from vllm_mindspore.v1.metrics.stats import update_from_output + +IterationStats.update_from_output = update_from_output + check_ready() diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py index f7a3d6aa..5c498ab9 100644 --- a/vllm_mindspore/engine/arg_utils.py +++ b/vllm_mindspore/engine/arg_utils.py @@ -219,17 +219,18 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: def _set_default_args_v1(self, usage_context: UsageContext) -> None: """Set Default Arguments for V1 Engine.""" - # V1 always uses chunked prefills. - self.enable_chunked_prefill = True + # Original vLLM V1 always uses chunked prefills. + # EnhancedScheduler can support enable_chunked_prefill=False + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = True # V1 enables prefix caching by default. if self.enable_prefix_caching is None: self.enable_prefix_caching = True - # V1 should use the new scheduler by default. - # Swap it only if this arg is set to the original V0 default - if self.scheduler_cls == EngineArgs.scheduler_cls: - self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" + # Use EnhancedScheduler in place of Scheduler + self.scheduler_cls = ( + "vllm_mindspore.v1.core.sched.scheduler.EnhancedScheduler") # vllm-mindspore: Get device memory will initialize device runtime, which # will be inherited by the child process in fork mode, resulting in diff --git a/vllm_mindspore/v1/core/sched/scheduler.py b/vllm_mindspore/v1/core/sched/scheduler.py index 072a5f28..45975412 100644 --- a/vllm_mindspore/v1/core/sched/scheduler.py +++ b/vllm_mindspore/v1/core/sched/scheduler.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -# Adapted from -# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/v1/core/sched/scheduler.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/ +# vllm/v1/core/sched/scheduler.py # # Copyright 2025 Huawei Technologies Co., Ltd. -# Copyright 2024-2025 The vLLM team. +# Copyright 2025 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -21,220 +21,766 @@ # noqa: G004 -from collections import defaultdict -from typing import Optional +import time +from collections import defaultdict, deque +from collections.abc import Iterable +from typing import Optional, Union +from vllm.config import VllmConfig +from vllm.distributed.kv_events import KVEventBatch from vllm.logger import init_logger -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.utils import check_stop -from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs, FinishReason +from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, + EngineCoreOutputs, FinishReason) +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm.v1.structured_output import StructuredOutputManager logger = init_logger(__name__) -def update_from_output( - self, - scheduler_output: SchedulerOutput, - model_runner_output: ModelRunnerOutput, -) -> dict[int, EngineCoreOutputs]: - sampled_token_ids = model_runner_output.sampled_token_ids - spec_token_ids = model_runner_output.spec_token_ids - logprobs = model_runner_output.logprobs - prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict - num_scheduled_tokens = scheduler_output.num_scheduled_tokens - - new_running: list[Request] = [] - outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) - spec_decoding_stats: Optional[SpecDecodingStats] = None - - # Add by vllm-mindspore begin: - running_req_ids = [req.request_id for req in self.running] - # abort_req_ids used to keep track of failed requests - # caused by model execution exception - abort_req_ids: list[str] = [] - # Add by vllm-mindspore end. - - # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below - # loop can be a performance bottleneck. We should do our best to avoid - # expensive operations inside the loop. - for request in self.running: - req_id = request.request_id - # Add by vllm-mindspore begin: - # None sampled_token_ids comes from exception model execution, - # set them to abort list - # to keep main scheduler task running right. - if sampled_token_ids is None: - logger.warning( - 'Process aborted request %s from running requests %s', req_id, - running_req_ids) - outputs[request.client_index].append( - EngineCoreOutput(request_id=req_id, - new_token_ids=[], - finish_reason=FinishReason.ABORT, - new_logprobs=None, - new_prompt_logprobs_tensors=None, - stop_reason=request.stop_reason, - events=request.take_events())) - abort_req_ids.append(req_id) - continue - # Add by vllm-mindspore end. +# Refer to +# https://github.com/vllm-project/vllm/blob/main/vllm/ +# v1/core/sched/scheduler.py +class EnhancedScheduler(Scheduler): + """ + Enhance the V1 scheduler with prefill-first scheduling strategy. 
+ """ + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + super().__init__(vllm_config, kv_cache_config, + structured_output_manager, mm_registry, + include_finished_set, log_stats) + self.scheduled_req_ids: set[str] = set() + self.ignored_req_ids: set[str] = set() + self.running: list[Request] = [] + + self.max_kvcache_size = (self.kv_cache_config.num_blocks * + self.block_size) + + def schedule(self) -> SchedulerOutput: + # default + if self.scheduler_config.chunked_prefill_enabled: + return super().schedule() + + return self._schedule_no_chunked() + + def _schedule_no_chunked(self) -> SchedulerOutput: + """ + 1. Supports no cp and the prefill-first scheduling strategy. + 2. Does not support PD disaggregation at present. + """ + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + structured_output_request_ids: dict[str, int] = {} + + req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # Record scheduled LoRA requests. + scheduled_loras: set[int] = set() + + # For logging. + scheduled_timestamp = time.monotonic() + + # First, schedule the WAITING prefill requests. + skipped_waiting_requests: deque[Request] = deque() + req_index = 0 + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break - num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) - if num_tokens_scheduled == 0: - # The request was not scheduled in this step. - new_running.append(request) - continue - - req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[req_index] - - scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id)) - if scheduled_spec_token_ids: - # num_computed_tokens represents the number of tokens - # processed in the current step, considering scheduled - # tokens and rejections. If some tokens are rejected, - # num_computed_tokens is decreased by the number of rejected - # tokens, where is given by: - # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). - num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - - len(generated_token_ids)) - request.num_computed_tokens -= num_tokens_rejected - spec_decoding_stats = self.make_spec_decoding_stats( - spec_decoding_stats, - num_draft_tokens=len(scheduled_spec_token_ids), - num_accepted_tokens=len(generated_token_ids) - 1) - - cached_encoder_input_ids = ( - self.encoder_cache_manager.get_cached_input_ids(request)) - # OPTIMIZATION: Avoid list(set) if the set is empty. - if cached_encoder_input_ids: - for input_id in list(cached_encoder_input_ids): - mm_positions = request.mm_positions[input_id] - start_pos = mm_positions.offset - num_tokens = mm_positions.length - if start_pos + num_tokens <= request.num_computed_tokens: - # The encoder output is already processed and stored - # in the decoder's KV cache. 
- self.encoder_cache_manager.free_encoder_input( - request, input_id) - - stopped = False - new_logprobs = None - new_token_ids = generated_token_ids - kv_transfer_params = None - - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner - # to return empty token ids for the request. - for num_new, output_token_id in enumerate(new_token_ids, 1): - request.append_output_token_ids(output_token_id) - - # Check for stop and update request state. - # This must be called before we make the EngineCoreOutput. - stopped = check_stop(request, self.max_model_len) - if stopped: - kv_transfer_params = self._free_request(request) - del new_token_ids[num_new:] # Trim new tokens if needed. + request = self.waiting[0] + + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id) + self.waiting.pop_request() + skipped_waiting_requests.appendleft(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request and ( + len(scheduled_loras) == self.lora_config.max_loras and + request.lora_request.lora_int_id not in scheduled_loras): + # Scheduling would exceed max_loras, skip. + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + new_computed_blocks, num_new_local_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens = (num_new_local_computed_tokens + + num_external_computed_tokens) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + else: + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + # KVTransfer: loading remote KV, do not allocate for new work. + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + # Number of tokens to be scheduled. + else: + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. 
+ num_new_tokens = request.num_tokens - num_computed_tokens + assert num_new_tokens > 0 + + prompt_limit = self._get_prompt_limit(request) + prompt_limit = min(prompt_limit, self.max_kvcache_size) + if num_new_tokens > prompt_limit: + logger.warning( + "Request(%s) prompt has (%d tokens) , and it \ + exceeds system limit of %d. It will be ignored/abort.", + request.request_id, + num_new_tokens, + prompt_limit, + ) + request.status = RequestStatus.FINISHED_IGNORED + self.ignored_req_ids.add(request.request_id) + self.waiting.popleft() + continue + + # skip if new tokens exceed token_budget + if num_new_tokens > token_budget: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks, + num_lookahead_tokens=self.num_lookahead_tokens, + delay_cache_blocks=load_kv_async, + ) + if new_blocks is None: + # The request cannot be scheduled. break - # Extract sample logprobs if needed. - if request.sampling_params.logprobs is not None and logprobs: - # NOTE: once we support N tokens per step (spec decode), - # the outer lists can be of length > 1. - new_logprobs = logprobs.slice(req_index, req_index + 1) - - if new_token_ids and self.structured_output_manager.should_advance( - request): - """ - NOTE: structured_output_request - should not be None if use_structured_output, we have - check above, so safe to ignore type warning - """ - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) - - # Add newly generated spec token ids to the request. - if spec_token_ids is not None: - if self.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - # Needs to happen after new_token_ids are accepted. - request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids[req_index]) + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + new_computed_blocks + new_blocks, + num_external_computed_tokens, + ) + + self.waiting.popleft() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. 
+ skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + + if request.use_structured_output: + structured_output_request_ids[request.request_id] = req_index + req_index += 1 + + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + self.scheduled_req_ids.add(request.request_id) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) else: - request.spec_token_ids = spec_token_ids[req_index] + raise RuntimeError(f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_block_ids[request.request_id] = ( + self.kv_cache_manager.get_block_ids(request.request_id)) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.extendleft(skipped_waiting_requests) + + # Second, schedule the RUNNING decode requests if no prefill reqs. + if len(self.scheduled_req_ids) == 0: + req_index = 0 + token_budget = self.max_num_scheduled_tokens + + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + + # The request has been scheduled. + if request.request_id in self.scheduled_req_ids: + req_index += 1 + continue + + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) + # decode only + if (request.num_tokens - request.num_computed_tokens) != 1: + raise RuntimeError( + "The running queue only contains decode requests.") + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model + # len. This is necessary when using spec decoding. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - request.num_computed_tokens) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, + encoder_budget) + + if num_new_tokens == 0: + req_index += 1 + continue + + # Check that adding the request still respects the max_loras + # constraint. 
+ if self.lora_config and request.lora_request and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id + not in scheduled_loras): + req_index += 1 + continue + + num_draft_tokens = max( + num_new_tokens + request.num_computed_tokens - + request.num_tokens, 0) + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_draft_tokens=num_draft_tokens, + num_lookahead_tokens=self.num_lookahead_tokens) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + preempted_req = self.running.pop() + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, + scheduled_timestamp) + + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + self.scheduled_req_ids.add(request.request_id) + scheduled_running_reqs.append(request) + if request.use_structured_output: + structured_output_request_ids[ + request.request_id] = req_index + req_to_new_block_ids[request.request_id] = ( + new_blocks.get_block_ids()) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Record scheduled LoRA requests. + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs) <= len(self.running) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = [0] * len( + self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + grammar_bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + # Construct the scheduler output. 
+ new_reqs_data = [ + NewRequestData.from_request(req, + req_to_new_block_ids[req.request_id]) + for req in scheduled_new_reqs + ] + + resumed_reqs_data = [ + self._make_cached_request_data( + req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), + req_to_new_block_ids[req.request_id], + resumed_from_preemption=True, + ) for req in scheduled_resumed_reqs + ] + running_reqs_data = [ + self._make_cached_request_data( + req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), + req_to_new_block_ids[req.request_id], + resumed_from_preemption=False, + ) for req in scheduled_running_reqs + ] + scheduled_cached_reqs = resumed_reqs_data + running_reqs_data + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=scheduled_cached_reqs, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, # type: ignore[has-type] + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + events = self.kv_cache_manager.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g.spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + for req_id, num_scheduled_token in num_scheduled_tokens.items(): + self.requests[req_id].num_computed_tokens += num_scheduled_token + + self.finished_req_ids = set() # type: ignore[has-type, var-annotated] + return scheduler_output + + # Refer to vllm\core\scheduler.py Scheduler._get_prompt_limit + def _get_prompt_limit(self, seq_group) -> int: + prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) + + # Model is fine tuned with long context. Return the fine tuned max_len. 
+ if seq_group.lora_request and seq_group.lora_request.long_lora_max_len: + return seq_group.lora_request.long_lora_max_len + else: + return prompt_limit + + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: RequestStatus, + ) -> None: + if len(self.scheduled_req_ids) > 0: + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. + continue + if request.status == RequestStatus.RUNNING: + # Clear the scheduled reqs + self.scheduled_req_ids.discard(request.request_id) + + super().finish_requests(request_ids, finished_status) + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> EngineCoreOutputs: + if len(self.scheduled_req_ids) > 0: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + + for request in self.running: + req_id = request.request_id + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + continue + # Clear the scheduled reqs + if req_id in self.scheduled_req_ids: + self.scheduled_req_ids.discard(req_id) + + return self._update_from_output(scheduler_output, model_runner_output) + + def _update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + sampled_token_ids = model_runner_output.sampled_token_ids + spec_token_ids = model_runner_output.spec_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + + new_running: list[Request] = [] + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: Optional[SpecDecodingStats] = None - # Get prompt logprobs for this request. - prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or kv_transfer_params: + # Add by vllm-mindspore begin: + running_req_ids = [req.request_id for req in self.running] + # abort_req_ids used to keep track of failed requests + # caused by model execution exception + abort_req_ids: list[str] = [] + # Add by vllm-mindspore end. - # Add EngineCoreOutput for this Request. - outputs[request.client_index].append( - EngineCoreOutput( - request_id=req_id, - new_token_ids=new_token_ids, - finish_reason=request.get_finished_reason(), - new_logprobs=new_logprobs, - new_prompt_logprobs_tensors=prompt_logprobs_tensors, - stop_reason=request.stop_reason, - events=request.take_events(), - kv_transfer_params=kv_transfer_params, - num_cached_tokens=request.num_cached_tokens, - )) + # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below + # loop can be a performance bottleneck. We should do our best to avoid + # expensive operations inside the loop. + for request in self.running: + req_id = request.request_id + # Add by vllm-mindspore begin: + # None sampled_token_ids comes from exception model execution, + # set them to abort list + # to keep main scheduler task running right. + if sampled_token_ids is None: + logger.warning( + 'Process aborted request %s from running requests %s', + req_id, running_req_ids) + outputs[request.client_index].append( + EngineCoreOutput(request_id=req_id, + new_token_ids=[], + finish_reason=FinishReason.ABORT, + new_logprobs=None, + new_prompt_logprobs_tensors=None, + stop_reason=request.stop_reason, + events=request.take_events())) + abort_req_ids.append(req_id) + continue + # Add by vllm-mindspore end. 
+ + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. + new_running.append(request) + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[req_index] + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + if scheduled_spec_token_ids: + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens, where is given by: + # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). + num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - + len(generated_token_ids)) + request.num_computed_tokens -= num_tokens_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=len(scheduled_spec_token_ids), + num_accepted_tokens=len(generated_token_ids) - 1) + + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + # OPTIMIZATION: Avoid list(set) if the set is empty. + if cached_encoder_input_ids: + for input_id in list(cached_encoder_input_ids): + mm_positions = request.mm_positions[input_id] + start_pos = mm_positions.offset + num_tokens = mm_positions.length + if start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None + + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. + for num_new, output_token_id in enumerate(new_token_ids, 1): + request.append_output_token_ids(output_token_id) + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = check_stop(request, self.max_model_len) + if stopped: + kv_transfer_params = self._free_request(request) + del new_token_ids[num_new:] # Trim new tokens if needed. + break + + # Extract sample logprobs if needed. + if request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and self.structured_output_manager.should_advance( + request): + """ + NOTE: structured_output_request + should not be None if use_structured_output, we have + check above, so safe to ignore type warning + """ + request.structured_output_request.grammar.accept_tokens( + req_id, new_token_ids) # type: ignore[union-attr] + + # Add newly generated spec token ids to the request. + if spec_token_ids is not None: + if self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + # Needs to happen after new_token_ids are accepted. + request.spec_token_ids = metadata.grammar.validate_tokens( + spec_token_ids[req_index]) # type: ignore[union-attr] + else: + request.spec_token_ids = spec_token_ids[req_index] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or kv_transfer_params: + + # Add EngineCoreOutput for this Request. 
+ outputs[request.client_index].append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + num_cached_tokens=request.num_cached_tokens, + )) - else: - # Invariant: EngineCore returns no partial prefill outputs. - assert not prompt_logprobs_tensors - - if not stopped: - new_running.append(request) - - # Add by vllm-mindspore begin: - # make failed requests finished to make the server - # can continue to process new request - if len(abort_req_ids) > 0: - logger.warning('Aborted requests are %s', abort_req_ids) - self.finish_requests(abort_req_ids, RequestStatus.FINISHED_ABORTED) - # Add by vllm-mindspore end. - - # KV Connector: update state for finished KV Transfers. - self._update_from_kv_xfer_finished(model_runner_output) - - # Return the cached request data to the queue so they can be reused. - for req_data in scheduler_output.scheduled_cached_reqs: - # NOTE(rob): since we free stopped reqs above, adding stopped reqs - # to _cached_reqs_data will cause a memory leak. - if req_data.req_id not in self.finished_req_ids: - self._cached_reqs_data[req_data.req_id].append(req_data) - - self.running = new_running - - # Create EngineCoreOutputs for all clients that have requests with - # outputs in this step. - engine_core_outputs = { - client_index: EngineCoreOutputs(outputs=outs) - for client_index, outs in outputs.items() - } - - finished_req_ids = self.finished_req_ids_dict - if finished_req_ids: - # Include ids of requests that finished since last outputs - # were sent. - for client_index, finished_set in finished_req_ids.items(): - # Set finished request set in EngineCoreOutputs for this client. - if (eco := engine_core_outputs.get(client_index)) is not None: - eco.finished_requests = finished_set else: - engine_core_outputs[client_index] = EngineCoreOutputs( - finished_requests=finished_set) - finished_req_ids.clear() + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + if not stopped: + new_running.append(request) + + # Add by vllm-mindspore begin: + # make failed requests finished to make the server + # can continue to process new request + if len(abort_req_ids) > 0: + logger.warning('Aborted requests are %s', abort_req_ids) + self.finish_requests(abort_req_ids, RequestStatus.FINISHED_ABORTED) + + self._process_ignored_reqs(outputs) + # Add by vllm-mindspore end. - if engine_core_outputs: - # Return stats to only one of the front-ends. - next(iter(engine_core_outputs.values())).scheduler_stats = ( - self.make_stats(spec_decoding_stats)) + # KV Connector: update state for finished KV Transfers. + self._update_from_kv_xfer_finished(model_runner_output) + + # Return the cached request data to the queue so they can be reused. + for req_data in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): since we free stopped reqs above, adding stopped reqs + # to _cached_reqs_data will cause a memory leak. + if req_data.req_id not in self.finished_req_ids: + self._cached_reqs_data[req_data.req_id].append(req_data) + + self.running = new_running + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. 
+ engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if engine_core_outputs: + # Return stats to only one of the front-ends. + next(iter(engine_core_outputs.values())).scheduler_stats = ( + self.make_stats(spec_decoding_stats)) + + return engine_core_outputs + + def _process_ignored_reqs(self, outputs): + if len(self.ignored_req_ids) == 0: + return + + for req_id in self.ignored_req_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. + continue + outputs[request.client_index].append( + EngineCoreOutput(request_id=req_id, + new_token_ids=[], + finish_reason=request.get_finished_reason(), + new_logprobs=None, + new_prompt_logprobs_tensors=None, + stop_reason=request.stop_reason, + events=request.take_events())) - return engine_core_outputs + # Clear the req source + self._free_request(request) + self.ignored_req_ids.clear() diff --git a/vllm_mindspore/v1/metrics/__init__.py b/vllm_mindspore/v1/metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_mindspore/v1/metrics/stats.py b/vllm_mindspore/v1/metrics/stats.py new file mode 100644 index 00000000..2e59bb7e --- /dev/null +++ b/vllm_mindspore/v1/metrics/stats.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Functions are adapted from +# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/v1/metrics/stats.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Optional
+
+from vllm.v1.engine import EngineCoreOutput
+from vllm.v1.metrics.stats import LoRAStats, RequestStateStats
+
+
+# Refer to vllm/v1/metrics/stats.py
+# This patched function can be removed once vLLM 0.11 is adopted
+def update_from_output(self, output: "EngineCoreOutput",
+                       engine_core_timestamp: float, is_prefilling: bool,
+                       prompt_len: int, req_stats: RequestStateStats,
+                       lora_stats: Optional[LoRAStats]):
+    num_new_generation_tokens = len(output.new_token_ids)
+
+    self.num_generation_tokens += num_new_generation_tokens
+    if is_prefilling:
+        # Ignored/aborted requests generate no new tokens, so the
+        # 'assert num_new_generation_tokens > 0' (removed in 0.11) is dropped
+        self.num_prompt_tokens += prompt_len
+
+        first_token_latency = self._time_since(req_stats.arrival_time)
+        self.time_to_first_tokens_iter.append(first_token_latency)
+
+    req_stats.num_generation_tokens += num_new_generation_tokens
+
+    # Process request-level engine core events
+    if output.events is not None:
+        self.update_from_events(output.request_id, output.events,
+                                is_prefilling, req_stats, lora_stats)
+
+    # Process the batch-level "new tokens" engine core event
+    if is_prefilling:
+        req_stats.first_token_ts = engine_core_timestamp
+    else:
+        tpot = engine_core_timestamp - req_stats.last_token_ts
+        self.time_per_output_tokens_iter.append(tpot)
+
+    req_stats.last_token_ts = engine_core_timestamp
-- 
Gitee
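Usage sketch: a minimal example of exercising the no-chunked-prefill path added by this patch, mirroring test_enhanced_schedule_end_to_end above. The checkpoint path is a placeholder and the Ascend environment variables listed in the test are assumed to be set. With enable_chunked_prefill=False, _set_default_args_v1 keeps the flag off and installs EnhancedScheduler, whose schedule() then routes to _schedule_no_chunked().

    import vllm_mindspore  # noqa: F401  # apply the vLLM patches before using vllm
    from vllm import LLM, SamplingParams

    # Placeholder checkpoint path; any local Qwen3-style model as used by the tests.
    llm = LLM(model="/path/to/Qwen3-0.6B",
              gpu_memory_utilization=0.9,
              tensor_parallel_size=1,
              max_num_seqs=10,
              max_num_batched_tokens=128,
              max_model_len=256,
              enable_prefix_caching=False,
              enable_chunked_prefill=False)  # selects _schedule_no_chunked()

    sampling_params = SamplingParams(temperature=0.0, max_tokens=4, top_k=1)
    outputs = llm.generate(["Classify the sentiment: the holiday was okay."],
                           sampling_params)
    print(outputs[0].outputs[0].text, outputs[0].outputs[0].finish_reason)

Prompts whose token count exceeds min(max_model_len, max_num_batched_tokens, num_blocks * block_size) are finished as FINISHED_IGNORED instead of being scheduled, which is what the first half of the end-to-end test checks via finish_reason == "length" on the truncated output.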