diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 2a477cbc2e822c4cbcc147c39a9fb5a040d3b582..9972302f238b493567ebca425b27647cd0aa204b 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -374,14 +374,14 @@ from vllm_mindspore.v1.worker.gpu_model_runner import _update_states
 vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states
 
 from vllm_mindspore.v1.worker.gpu_model_runner import (
-    _allocate_kv_cache_tensors,
-    get_kv_cache_spec,
-)
+    _allocate_kv_cache_tensors, get_kv_cache_spec, initialize_kv_cache_tensors)
 
 vllm.v1.worker.gpu_model_runner.GPUModelRunner._allocate_kv_cache_tensors = (
     _allocate_kv_cache_tensors)
 vllm.v1.worker.gpu_model_runner.GPUModelRunner.get_kv_cache_spec = (
     get_kv_cache_spec)
+vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache_tensors = (
+    initialize_kv_cache_tensors)
 
 from vllm_mindspore.v1.worker.gpu_model_runner import _reshape_kv_cache_tensors
 vllm.v1.worker.gpu_model_runner.GPUModelRunner._reshape_kv_cache_tensors = (
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index b6b23c0e9600749be134f2b8cc0f1f524f43d20e..ed2b6ea15900ec30f570d6b40e9b8a5a30d3f98e 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -62,6 +62,22 @@ FORMAT_TYPE = {
 }
 
 
+def create_kv_cache(kv_shape, dtype):
+    if is_310p():
+        if len(kv_shape) != 4:
+            raise ValueError("format_cast op needs kv_cache shape to be "
+                             "(batch_size, num_heads, seq_len, head_dim), "
+                             f"but got {len(kv_shape)} dimensions: {kv_shape}")
+
+        batch_size, num_heads, seq_len, head_dim = kv_shape
+        reshaped_for_nz = (batch_size, num_heads, seq_len * head_dim)
+        zeros_tensor = ms.mint.zeros(reshaped_for_nz, dtype=dtype)
+
+        return ms.ops.auto_generate.format_cast(zeros_tensor,
+                                                FORMAT_TYPE['nz'])
+    return ms.mint.zeros(kv_shape, dtype=dtype)
+
+
 def get_valid_dtype(dtype):
     if isinstance(dtype, str):
         dtype = STR_DTYPE_TO_MS_DTYPE[dtype]
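Reviewer note: a minimal usage sketch of the new create_kv_cache helper, not part of the patch. The shape and dtype below are hypothetical (real shapes come from attn_backend.get_kv_cache_shape), and the 310P branch assumes format_cast preserves the logical shape while switching the internal layout to NZ.

# Hypothetical usage sketch for create_kv_cache (illustration only).
import mindspore as ms

from vllm_mindspore.utils import create_kv_cache, is_310p

# Illustrative shape only: (batch_size, num_heads, seq_len, head_dim).
kv_shape = (1024, 16, 8, 128)
cache = create_kv_cache(kv_shape, ms.float16)

if is_310p():
    # On 310P the last two dims are collapsed before the NZ format cast.
    assert cache.shape == (1024, 16, 8 * 128)
else:
    # On other targets the requested 4-D shape is kept as-is.
    assert cache.shape == kv_shape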
diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index 8bd37c26a0cf526fe8b371a2bcb911e48e24f15f..846c0268ea3c29875a6b06fd4079a526454bd8e2 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -23,20 +23,22 @@ from typing import Any, Optional
 import mindspore as ms
 import numpy as np
 from mindspore import Generator as msGenerator
-from mindspore import Tensor, mint, mutable, ops
+from mindspore import Tensor, mint, mutable
 from vllm.attention import AttentionType
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingType
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
-                                        SlidingWindowSpec)
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheSpec, SlidingWindowSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState
+from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing
 
 from vllm_mindspore.model_executor.layers.rotary_embedding import (
     InferMRotaryEmbedding as MRotaryEmbedding)
-from vllm_mindspore.utils import (FORMAT_TYPE, get_dtype_size, get_valid_dtype,
-                                  is_310p)
+from vllm_mindspore.utils import (create_kv_cache, get_dtype_size,
+                                  get_valid_dtype, is_310p)
 
 logger = init_logger(__name__)
 
@@ -209,6 +211,66 @@ def create_block(shape, dtype, name=None, device=None):
     return blocks
 
 
+def _allocate_nz_kv_cache_tensors(self, kv_cache_config):
+    """
+    Initializes and reshapes the KV cache buffer with the correct size.
+    The buffer needs to be converted to NZ format for 310P.
+
+    Args:
+        kv_cache_config: The KV cache config
+    Returns:
+        dict[str, Tensor]: A map from layer names to their
+            corresponding memory buffers for KV cache.
+    """
+    kv_caches: dict[str, tuple] = {}
+
+    layer_to_group_info = {
+        layer_name: (i, group.kv_cache_spec)
+        for i, group in enumerate(kv_cache_config.kv_cache_groups)
+        for layer_name in group.layer_names
+    }
+
+    use_mla_op = bool(
+        self.vllm_config.additional_config
+        and self.vllm_config.additional_config.get('use_mla_op') == 1)
+    if use_mla_op:
+        logger.error("MLA KV cache is not supported on 310P.")
+        raise NotImplementedError
+
+    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+        if not kv_cache_tensor.shared_by:
+            continue
+
+        rep_layer_name = kv_cache_tensor.shared_by[0]
+        group_idx, kv_cache_spec = layer_to_group_info[rep_layer_name]
+        if not isinstance(kv_cache_spec, FullAttentionSpec):
+            raise NotImplementedError
+
+        attn_backend = self.attn_backends[group_idx]
+        target_dtype = get_valid_dtype(kv_cache_spec.dtype)
+
+        num_blocks = kv_cache_tensor.size // kv_cache_spec.page_size_bytes
+        kv_cache_shape = attn_backend.get_kv_cache_shape(
+            num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads,
+            kv_cache_spec.head_size)
+
+        reshaped_layer_tensors = []
+        coef = 1 if kv_cache_spec.use_mla else 2  # one buffer for MLA, else K and V
+        for _ in range(coef):
+            reshaped_layer_tensors.append(
+                create_kv_cache(kv_cache_shape[1:], target_dtype))
+
+        final_kv_tuple = mutable(tuple(reshaped_layer_tensors))
+        for layer_name in kv_cache_tensor.shared_by:
+            kv_caches[layer_name] = final_kv_tuple
+
+    all_layers = set(layer_to_group_info.keys())
+    if all_layers != set(kv_caches.keys()):
+        raise RuntimeError("Some layers were not initialized")
+
+    return kv_caches
+
+
 def _allocate_kv_cache_tensors(self, kv_cache_config):
     """
     Initializes the KV cache buffer with the correct size. The buffer needs
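Reviewer note: the allocation loop above sizes each layer's cache from the byte budget the KV cache config assigns to its tensor. A back-of-the-envelope sketch with hypothetical numbers, assuming page_size_bytes for a full-attention spec covers both the K and V planes of one block:

# Hypothetical arithmetic only; real values come from kv_cache_spec.
block_size, num_kv_heads, head_size = 16, 8, 128
dtype_size = 2  # float16
# Assumed formula: K and V planes per block (coef == 2 for non-MLA caches).
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size

kv_cache_tensor_size = 512 * 1024 * 1024  # assume a 512 MiB budget
num_blocks = kv_cache_tensor_size // page_size_bytes
print(num_blocks)  # 8192 blocks of 16 tokens each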
+ """ + if is_310p(): + kv_caches = _allocate_nz_kv_cache_tensors(self, kv_cache_config) + else: + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, + kv_cache_raw_tensors) + + # Setup `kv_cache_config` and `kv_caches` for models + # with cross-layer KV sharing + if self.shared_kv_cache_layers: + initialize_kv_cache_for_kv_sharing( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + kv_caches, + ) + + bind_kv_cache(kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + return kv_caches + + def _update_states(self, scheduler_output) -> None: """Update the cached states and the persistent batch with the scheduler output.