From 184110eb24dca41261d6284084e052d095486869 Mon Sep 17 00:00:00 2001
From: tronzhang
Date: Sun, 28 Sep 2025 20:08:42 +0800
Subject: [PATCH] disable cache for multimodal when dp>1

---
 vllm_mindspore/platforms/ascend.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py
index 8e39c3d15..fdff477a3 100644
--- a/vllm_mindspore/platforms/ascend.py
+++ b/vllm_mindspore/platforms/ascend.py
@@ -105,6 +105,19 @@ class AscendPlatform(Platform):
         model_config = vllm_config.model_config
         model_config.disable_cascade_attn = True
 
+        # The cache between p0 and p1 only works for one-to-one mapping. Data
+        # parallelism is a one-to-many scenario, so the cache must be disabled.
+        if (not model_config.disable_mm_preprocessor_cache
+                and parallel_config.data_parallel_size > 1):
+            if model_config.multimodal_config is None:
+                raise RuntimeError(
+                    "To disable the mm preprocessor cache, multimodal_config "
+                    "must not be None!")
+            model_config.multimodal_config.disable_mm_preprocessor_cache = True
+            logger.info(
+                "Disable mm preprocessor cache for data parallel size %d.",
+                parallel_config.data_parallel_size)
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla):
-- 
Gitee