diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py b/vllm_mindspore/model_executor/models/mf_models/mindformers.py
index b5a288cd30675fd7f1cbcf2f6c9e32579faf9b3c..b0c3020a7b27f8d048da9b2ec613d47c9608f655 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py
@@ -23,6 +23,7 @@ from mindformers import AutoModel, PreTrainedModel
 from mindformers.core.context import build_mf_context
 from mindformers.tools.utils import is_pynative
 from mindspore import Tensor, mutable, ops
+from mindspore.common.api import _pynative_executor
 from mindspore.nn.utils import no_init_parameters
 from vllm import envs
 from vllm.config import VllmConfig, get_current_vllm_config
@@ -42,6 +43,15 @@ from vllm_mindspore.model_executor.models.utils import (
     is_use_ringmla, make_empty_intermediate_tensors_factory)
 from vllm_mindspore.utils import is_310p
 
+try:
+    # Need to apply dllm pd patch on vllm to use pd disagg related functions
+    from vllm.attention.layer import (maybe_save_kv_layer_to_connector,
+                                      wait_for_kv_layer_from_connector)
+    from vllm.distributed.kv_transfer import is_v1_kv_transfer_group
+    kv_transfer_supported = True
+except:  # noqa: E722
+    kv_transfer_supported = False
+
 logger = init_logger(__name__)
 
 
@@ -57,6 +67,7 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
         mf_config.load_checkpoint = self.get_model_path()
         mf_config.pretrained_model_dir = self.get_model_path()
         self.mf_config = mf_config
+        self.kv_transfer_config = vllm_config.kv_transfer_config
         self.mla_config = self.mf_config.get('model', None).get(
             'model_config', None).get('multi_latent_attention', False)
         self.use_ringmla = is_use_ringmla(vllm_config, mf_config)
@@ -347,7 +358,6 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
         """
         model_inputs, is_prefill, is_ringmla_chunked = self.prepare_inputs(
             input_ids, positions)
-        model_inputs = self.update_model_inputs(model_inputs, **kwargs)
         if intermediate_tensors is not None:
             model_inputs["hidden_states"] = \
                 intermediate_tensors["hidden_states"]
@@ -369,7 +379,16 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
                 self.network.add_flags_custom_mcore(is_prefill=False)
             self.set_flags = (True
                               if not self.use_ringmla else self.is_chunked)
+            if kv_transfer_supported and is_v1_kv_transfer_group():
+                self.connector_send_kvcache()
         else:
+            if kv_transfer_supported:
+                if is_v1_kv_transfer_group() and self.is_prefill_task():
+                    self.connector_send_kvcache()
+
+                if is_v1_kv_transfer_group() and self.is_decoder_task():
+                    self.connector_wait_for_kv_layer()
+                    logger.debug("connector_wait_for_kv_layer success")
             hidden_states = self.network(**model_inputs)
 
         if not get_pp_group().is_last_rank:
@@ -387,9 +406,6 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
             return network, network.model.output_layer
         return network, None
 
-    def update_model_inputs(self, model_inputs, **kwargs):
-        return model_inputs
-
     def compute_logits(
         self,
         hidden_states: Tensor,
@@ -423,3 +439,29 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
     def load_weights(self, weights: Iterable[tuple[str, Tensor]]):
         self.network.load_weights(self.mf_config.load_checkpoint)
         return None
+
+    def is_decoder_task(self) -> bool:
+        if self.kv_transfer_config is None:
+            return False
+
+        return self.kv_transfer_config.is_kv_consumer
+
+    def is_prefill_task(self) -> bool:
+        if self.kv_transfer_config is None:
+            return False
+
+        return self.kv_transfer_config.is_kv_producer
+
+    def connector_send_kvcache(self):
+        logger.debug("reached connector_send_kvcache")
+        _pynative_executor.sync()
+        forward_context = get_forward_context()
+        for i in range(self.mf_model_config.num_layers):
+            kv_cache = self.kv_caches[i].kv_cache[
+                forward_context.virtual_engine]
+            maybe_save_kv_layer_to_connector(str(i), kv_cache)
+
+    def connector_wait_for_kv_layer(self):
+        logger.debug("reached connector_wait_for_kv_layer")
+        for i in range(self.mf_model_config.num_layers):
+            wait_for_kv_layer_from_connector("key." + str(i))