diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index efd917b0a3cef0cc0d11212dc25de9dc60166260..2bfd67541594f0354701fc5aa8d6d7218e2e3b83 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -22,6 +22,8 @@ import warnings import msadapter # noqa: F401 from vllm_mindspore.ray_patch import patch_ray +is_dispatch_req_all_depend_p0 = True + patch_ray() if "vllm" in sys.modules: @@ -565,11 +567,15 @@ vllm.entrypoints.cli.serve.CoreEngine = MsCoreEngine vllm.v1.engine.core_client.CoreEngine = MsCoreEngine vllm.v1.utils.CoreEngine = MsCoreEngine -from vllm.v1.engine.core_client import DPAsyncMPClient +if is_dispatch_req_all_depend_p0: + # Dispatch the request based on the status stored on p0, + # instead of load-balance statuses published by p1. + from vllm.v1.engine.core_client import DPAsyncMPClient -DPAsyncMPClient.get_core_engine_for_request = get_core_engine_for_request -DPAsyncMPClient.add_request_async = add_request_async -DPAsyncMPClient.process_engine_outputs = staticmethod(process_engine_outputs) + DPAsyncMPClient.get_core_engine_for_request = get_core_engine_for_request + DPAsyncMPClient.add_request_async = add_request_async + DPAsyncMPClient.process_engine_outputs = staticmethod( + process_engine_outputs) from vllm.v1.engine.processor import Processor