diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py b/vllm_mindspore/model_executor/models/mf_models/mindformers.py
index ebd9628e1e0b6454370d9443bd5636c68c63b235..df1c71711d37453f103b3651c7ea47630de6f87f 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py
@@ -62,9 +62,6 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
             'model_config', None).get('multi_latent_attention', False)
         self.use_ringmla = is_use_ringmla(vllm_config, mf_config)
         self.mf_config.model.model_config.use_fused_mla = self.use_ringmla
-        # run chunked graph independently only if ringmla enabled.
-        self.set_chunked_flags = not self.use_ringmla
-        self.set_decode_flags = False

         build_mf_context(self.mf_config)
         mf_par_ctx = build_parallel_context(self.mf_config)
@@ -367,24 +364,39 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
             model_inputs["hidden_states"] = convert_pin(
                 kwargs["previous_hidden_states"])

-        self.network.phase = "prefill" if is_prefill else \
-            "chunked" if is_ringmla_chunked else "increment"
-        if (not self.set_flags or not self.set_chunked_flags
-                or self.is_eager_mode):
-            self.set_flags = True
-            self.network.add_flags_custom_mcore(is_prefill=True)
-            if hasattr(self.network, 'add_flags_chunked'):
-                # chunked means 3-rd graph "chunked"
-                self.network.add_flags_chunked(is_chunked=is_ringmla_chunked)
-            # ringmla_chunked means computing chunked-prefills on ringmla
-            self.set_chunked_flags |= is_ringmla_chunked
-        elif not self.set_decode_flags or self.is_eager_mode:
-            self.network.add_flags_custom_mcore(is_prefill=False)
-            if hasattr(self.network, 'add_flags_chunked'):
-                self.network.add_flags_chunked(is_chunked=False)
-            self.set_decode_flags = True
-
-        hidden_states = self.network(**model_inputs)
+        def _set_network_flags(prefill_flag, chunked_flag):
+            self.network.add_flags_custom_mcore(is_prefill=prefill_flag)
+            if (hasattr(self.network, "add_flags_chunked")
+                    and is_ringmla_chunked):
+                self.network.add_flags_chunked(is_chunked=chunked_flag)
+
+        if self.is_eager_mode:
+            # In eager_mode, flags only need to be refreshed for a new
+            # prefill or chunked step; plain decode steps reuse them.
+            need_set_flag = is_prefill or is_ringmla_chunked
+        else:
+            # In graph_mode, flags must be set on every call until all
+            # inference graphs have been executed once (prefill/decode,
+            # plus chunked only if ringmla is enabled).
+            need_set_flag = (not self.has_prefill_warmup
+                             or not self.has_chunked_warmup)
+        self.network.phase = "prefill" if is_prefill \
+            else "chunked" if is_ringmla_chunked else "decode"
+
+        # has_prefill_warmup and has_chunked_warmup indicate whether the
+        # corresponding inference graph has been executed. If ringmla is
+        # disabled, has_chunked_warmup is initialized to True, meaning the
+        # chunked graph never needs to be executed, so it is never warmed
+        # up on its own.
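For reference, here is a minimal standalone sketch of the flag-bracketing control flow that the forward() hunk above introduces. `MockNetwork` and `ForwardFlagsSketch.run_forward()` are hypothetical stand-ins for `self.network` and `MindFormersForCausalLM.forward()`; only the branching on `is_eager_mode`, `has_prefill_warmup` and `has_chunked_warmup` mirrors the patch, and it runs without MindSpore or MindFormers installed.

```python
class MockNetwork:
    """Stand-in for the MindFormers network object used by forward()."""

    def __init__(self):
        self.phase = None
        self.flag_calls = []

    def add_flags_custom_mcore(self, is_prefill):
        self.flag_calls.append(("mcore", is_prefill))

    def add_flags_chunked(self, is_chunked):
        self.flag_calls.append(("chunked", is_chunked))

    def __call__(self, **model_inputs):
        return "hidden_states[{}]".format(self.phase)


class ForwardFlagsSketch:
    """Mirrors only the warmup-flag handling of the patched forward()."""

    def __init__(self, use_ringmla, is_eager_mode):
        self.network = MockNetwork()
        self.use_ringmla = use_ringmla
        self.is_eager_mode = is_eager_mode
        self.has_prefill_warmup = False
        # Without ringmla there is no separate chunked graph to warm up.
        self.has_chunked_warmup = not use_ringmla

    def run_forward(self, is_prefill, is_ringmla_chunked, **model_inputs):

        def _set_network_flags(prefill_flag, chunked_flag):
            self.network.add_flags_custom_mcore(is_prefill=prefill_flag)
            if is_ringmla_chunked:
                self.network.add_flags_chunked(is_chunked=chunked_flag)

        if self.is_eager_mode:
            # Eager mode: refresh flags only for prefill/chunked steps.
            need_set_flag = is_prefill or is_ringmla_chunked
        else:
            # Graph mode: keep setting flags until every graph has run once.
            need_set_flag = (not self.has_prefill_warmup
                             or not self.has_chunked_warmup)
        self.network.phase = ("prefill" if is_prefill else
                              "chunked" if is_ringmla_chunked else "decode")

        if need_set_flag:
            # Bracket the network call: prefill/chunked flags on, then off.
            _set_network_flags(True, True)
            hidden_states = self.network(**model_inputs)
            _set_network_flags(False, False)
            self.has_prefill_warmup = True
            self.has_chunked_warmup = (not self.use_ringmla
                                       or is_ringmla_chunked)
        else:
            hidden_states = self.network(**model_inputs)
        return hidden_states


if __name__ == "__main__":
    model = ForwardFlagsSketch(use_ringmla=True, is_eager_mode=False)
    for prefill, chunked in [(True, False), (False, True), (False, False)]:
        out = model.run_forward(is_prefill=prefill, is_ringmla_chunked=chunked)
        print(out, len(model.network.flag_calls))
```

Running the sketch shows that, in graph mode with ringmla enabled, the flag-setting calls stop after the prefill and chunked steps have each executed once; the plain decode step no longer touches the flags.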
+        if need_set_flag:
+            _set_network_flags(True, True)
+            hidden_states = self.network(**model_inputs)
+            _set_network_flags(False, False)
+            self.has_prefill_warmup = True
+            self.has_chunked_warmup = (not self.use_ringmla
+                                       or is_ringmla_chunked)
+        else:
+            hidden_states = self.network(**model_inputs)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
diff --git a/vllm_mindspore/model_executor/models/mindone_models/transformers.py b/vllm_mindspore/model_executor/models/mindone_models/transformers.py
index d7b739f745708438a472722fb1352783a834d0f6..c0d24e7ca7d6d4087bbc0372ac49a6ea5e1d5735 100644
--- a/vllm_mindspore/model_executor/models/mindone_models/transformers.py
+++ b/vllm_mindspore/model_executor/models/mindone_models/transformers.py
@@ -542,8 +542,8 @@ class TransformersForCausalLM(MindONEModelBase):
                                                  q_seq_lens, block_tables)

         # for dummy_attention_metadata
-        if is_prefill and not self.set_flags:  #type: ignore
-            self.set_flags = True
+        if is_prefill and not self.has_prefill_warmup:  #type: ignore
+            self.has_prefill_warmup = True

         set_model_context("is_prefill", is_prefill)

diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 1e4627c6b45044ac404f40bbfb5115a81bc66c19..bfd1651e324a31b191ed081a8ba1ce154e886d33 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -138,8 +138,9 @@ class MsModelBase:
         self.num_layers = self.model_config.get_num_layers(
             self.parallel_config)

-        self.set_flags: bool = False
-        self.set_chunked_flags: bool = False
+        self.use_ringmla: bool = False
+        self.has_prefill_warmup: bool = False
+        self.has_chunked_warmup: bool = not self.use_ringmla
         self.kv_caches: list[Any] = []
         self.casual_mask = LowerTriangularMask(
             dtype=self.model_config.dtype,
@@ -290,11 +291,14 @@ class MsModelBase:
         seq_lengths = ms.Tensor([input_len], dtype=ms.int32)
         q_seq_lens_np = np.array([input_len], dtype=np.int32)
         seq_lens_np = np.array([input_len], dtype=np.int32)
-        context_lens_tensor = ms.Tensor([0], dtype=ms.int32) if not \
-            self.set_flags else ms.Tensor([1], dtype=ms.int32)
-        # create input for chunked graph.
+        # context len is 0 for prefill, and 1 for chunked and decode.
+        context_lens_tensor = ms.Tensor([0], dtype=ms.int32) if not (
+            self.has_chunked_warmup) else ms.Tensor([1], dtype=ms.int32)
+        # num_prompt_tokens is equal to seq_len for prefill and decode,
+        # and equal to seq_len + 1 for chunked.
         num_prompt_tokens = seq_lengths + 1 \
-            if (self.set_flags and not self.set_chunked_flags) else seq_lengths
+            if (self.has_prefill_warmup and not self.has_chunked_warmup) \
+            else seq_lengths
         block_tables = ms.Tensor([[0]], dtype=ms.int32)
         slot_mapping = [-1 for _ in range(input_len)]
         slot_mapping = ms.Tensor(slot_mapping, dtype=ms.int32)
@@ -308,7 +312,7 @@ class MsModelBase:
             context_lens=context_lens_tensor,
             # To enforce prefill and decode are both complied in warmup process.
             # So set max_context_lens to 0 for prefill and 1 for decode.
-            max_context_lens=0 if not self.set_flags else 1,
+            max_context_lens=0 if not self.has_prefill_warmup else 1,
             query_start_loc=None,
             num_prompt_tokens=num_prompt_tokens)

@@ -510,8 +514,8 @@ class NativeModel(MsModelBase):
                                              inputs_embeds)

         # for dummy_attention_metadata
-        if is_prefill and not self.set_flags:
-            self.set_flags = True
+        if is_prefill and not self.has_prefill_warmup:
+            self.has_prefill_warmup = True

         # eager mode
         if self.is_eager_mode:
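The _dummy_attention_metadata changes above derive the warmup inputs from the two flags. Below is a small reading aid for that mapping; `dummy_metadata_fields()` is a hypothetical helper that returns plain ints instead of `ms.Tensor` objects, so it runs without MindSpore, and it only mirrors the three flag-dependent fields from the hunks above.

```python
def dummy_metadata_fields(has_prefill_warmup, has_chunked_warmup, seq_len=1):
    """Mirror the flag-dependent fields built in _dummy_attention_metadata."""
    # context_lens is 0 until the chunked graph has been warmed up.
    context_lens = 0 if not has_chunked_warmup else 1
    # Only the state with prefill done but chunked still pending gets
    # seq_len + 1 prompt tokens; the other states get seq_len.
    num_prompt_tokens = (seq_len + 1
                         if has_prefill_warmup and not has_chunked_warmup
                         else seq_len)
    # max_context_lens selects the prefill graph (0) or the decode graph (1).
    max_context_lens = 0 if not has_prefill_warmup else 1
    return {
        "context_lens": context_lens,
        "num_prompt_tokens": num_prompt_tokens,
        "max_context_lens": max_context_lens,
    }


if __name__ == "__main__":
    # Flag states reachable during warmup: with ringmla the model starts at
    # (False, False); without ringmla has_chunked_warmup starts at True.
    for flags in [(False, False), (True, False), (False, True), (True, True)]:
        print(flags, dummy_metadata_fields(*flags))
```

The real method additionally builds the tensor/numpy inputs (seq_lengths, block_tables, slot_mapping) that are unaffected by the flags, so they are omitted here.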
diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py
index d7d07cd3b5d9400e12f8a54cad833847bca1d8ed..3e2e45370ef7ab87378c894136db83f03f5a3e85 100644
--- a/vllm_mindspore/v1/worker/gpu_worker.py
+++ b/vllm_mindspore/v1/worker/gpu_worker.py
@@ -30,7 +30,8 @@ def compile_or_warm_up_model(self) -> None:

     # Since prefill is done previously, we do decode here.
     default_max_num_reqs = 1  # For MindSpore, we only do one more decode here.
-    if hasattr(self.model_runner.model, 'set_chunked_flags'):
+    if hasattr(self.model_runner.model, 'has_chunked_warmup') \
+            and not self.model_runner.model.has_chunked_warmup:
         logger.info("Warmup for chunked graph.")
         self.model_runner._dummy_run(num_tokens=default_max_num_reqs)
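Finally, a sketch of the worker-side guard: the extra dummy run now fires only while the model still reports `has_chunked_warmup == False`, instead of whenever the old `set_chunked_flags` attribute existed. `FakeModel` and `FakeRunner` are hypothetical stand-ins for the model and the model runner; only the hasattr-plus-flag check mirrors the patched compile_or_warm_up_model().

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("warmup-sketch")


class FakeModel:
    """Stand-in exposing only the warmup flag checked by the worker."""

    def __init__(self, use_ringmla):
        # Without ringmla the chunked graph is considered already warmed up.
        self.has_chunked_warmup = not use_ringmla


class FakeRunner:
    """Stand-in for the model runner's dummy-run entry point."""

    def __init__(self, model):
        self.model = model

    def _dummy_run(self, num_tokens):
        logger.info("dummy run with %d token(s)", num_tokens)
        self.model.has_chunked_warmup = True


def maybe_warmup_chunked(model_runner):
    default_max_num_reqs = 1
    # Only models exposing the flag, and only while the chunked graph has
    # not been executed yet, get the extra warmup pass.
    if hasattr(model_runner.model, "has_chunked_warmup") \
            and not model_runner.model.has_chunked_warmup:
        logger.info("Warmup for chunked graph.")
        model_runner._dummy_run(num_tokens=default_max_num_reqs)


if __name__ == "__main__":
    maybe_warmup_chunked(FakeRunner(FakeModel(use_ringmla=True)))   # runs
    maybe_warmup_chunked(FakeRunner(FakeModel(use_ringmla=False)))  # skipped
```

This keeps the extra compile pass strictly opt-in: models that never expose the flag, or that already start with `has_chunked_warmup = True`, skip the chunked dummy run entirely.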