From 21a7fe8b2014b16b2c50c18ace61695ec8eb8c3d Mon Sep 17 00:00:00 2001
From: wusimin
Date: Thu, 27 Feb 2025 18:03:53 +0800
Subject: [PATCH] Add MLA attention implementation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 vllm_mindspore/__init__.py                   |  2 +
 vllm_mindspore/attention/backends/ms_attn.py | 56 +++++++++++++++++
 vllm_mindspore/attention/selector.py         | 18 +++++-
 vllm_mindspore/worker/cache_engine.py        | 65 ++++++++++++++++++--
 4 files changed, 133 insertions(+), 8 deletions(-)

diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 672be957ab..d86791a551 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -124,10 +124,12 @@ from vllm_mindspore.worker.cache_engine import (
     ms_allocate_kv_cache,
     ms_swap_in,
     ms_swap_out,
+    cache_engine_init
 )
 
 import vllm.worker.cache_engine
 
+vllm.worker.cache_engine.CacheEngine.__init__ = cache_engine_init
 vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache
 vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in
 vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out
diff --git a/vllm_mindspore/attention/backends/ms_attn.py b/vllm_mindspore/attention/backends/ms_attn.py
index 6febd68b7b..2b5684165a 100644
--- a/vllm_mindspore/attention/backends/ms_attn.py
+++ b/vllm_mindspore/attention/backends/ms_attn.py
@@ -459,3 +459,59 @@ class MsAttentionImpl(AttentionImpl):
         NOTE: It in-place updates the output tensor.
         """
         pass
+
+class MLABackend(AttentionBackend):
+    @staticmethod
+    def get_name() -> str:
+        return "TRITON_MLA"
+
+    @staticmethod
+    def get_impl_cls() -> Type["AttentionImpl"]:
+        return MsAttentionImpl
+
+    @staticmethod
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return MSAttentionMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["MsAttentionMetadataBuilder"]:
+        return MsAttentionMetadataBuilder
+
+    @staticmethod
+    def get_state_cls() -> Type["AttentionState"]:
+        return MsAttentionState
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,  # assumed to be 1 for MLA
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (1, num_blocks, block_size, 1, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        src_key_cache = src_kv_cache[0]
+        dst_key_cache = dst_kv_cache[0]
+        swap_cache(src_key_cache, dst_key_cache, src_to_dst)
+
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        blocks_to_copy = src_to_dists.asnumpy().tolist()
+        for kv_cache in kv_caches:
+            npu_key_block, npu_value_block = kv_cache
+            for src, dst in blocks_to_copy:
+                npu_key_block[dst, :] = npu_key_block[src, :]
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [576]
\ No newline at end of file
diff --git a/vllm_mindspore/attention/selector.py b/vllm_mindspore/attention/selector.py
index 508b7f8b69..c92961ccba 100644
--- a/vllm_mindspore/attention/selector.py
+++ b/vllm_mindspore/attention/selector.py
@@ -40,6 +40,7 @@ def which_attn_to_use(
     block_size: int,
     is_attention_free: bool,
     use_v1: bool = False,
+    use_mla: bool = False
 ) -> _Backend:
     """Returns which flash attention backend to use."""
     selected_backend = _Backend.FLASH_ATTN
@@ -53,7 +54,8 @@
             "MindSpore donot support %s attention backend now!"
             % str(backend_by_env_var)
         )
-
+    if use_mla:
+        return "MLA_ATTN"
     # get device-specific default attn_backend
     default_backend = current_platform.get_default_attn_backend(selected_backend)
     if default_backend is not None:
@@ -71,6 +73,7 @@ def _cached_get_attn_backend(
     is_attention_free: bool,
     is_blocksparse: bool = False,
     use_v1: bool = False,
+    use_mla: bool = False
 ) -> Type[AttentionBackend]:
     if is_blocksparse:
         logger.warning(
@@ -78,7 +81,7 @@
     )
 
     backend = which_attn_to_use(
-        head_size, dtype, kv_cache_dtype, block_size, is_attention_free, use_v1
+        head_size, dtype, kv_cache_dtype, block_size, is_attention_free, use_v1, use_mla
     )
     if backend == _Backend.FLASH_ATTN:
         logger.info("Using Flash Attention backend.")
@@ -87,6 +90,15 @@
         )
 
         return MsAttentionBackend
+
+    elif backend == "MLA_ATTN":
+        logger.info("Using MLA Attention backend.")
+        from vllm_mindspore.attention.backends.ms_attn import (
+            MLABackend
+        )
+
+        return MLABackend
+
     else:
         raise ValueError("Invalid attention backend.")
 
@@ -98,6 +110,7 @@ def get_ms_attn_backend(
     block_size: int,
     is_attention_free: bool,
     is_blocksparse: bool = False,
+    use_mla: bool = False
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
     # Accessing envs.* behind an @lru_cache decorator can cause the wrong
@@ -112,4 +125,5 @@
         is_attention_free=is_attention_free,
         is_blocksparse=is_blocksparse,
         use_v1=envs.VLLM_USE_V1,
+        use_mla=use_mla,
     )
diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py
index 714a22d672..a7e4024d3e 100644
--- a/vllm_mindspore/worker/cache_engine.py
+++ b/vllm_mindspore/worker/cache_engine.py
@@ -23,9 +23,11 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 from vllm_mindspore.utils import MsKVCache, get_valid_dtype
-
+from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType)
+from vllm.attention import get_attn_backend
 import mindspore as ms
-
+import os
 
 def create_block(shape, dtype, name=None, device=None):
     from mindspore.ops.function.array_func import empty as empty_tensor
@@ -39,10 +41,10 @@ def ms_allocate_kv_cache(
     device: str,
 ) -> List[MsKVCache]:
     """Allocates KV cache on the specified device."""
-    # kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-    #     num_blocks, self.block_size, self.num_kv_heads, self.head_size
-    # )
-    kv_cache_shape = (2, 512, 16, 1, 576)
+    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+        num_blocks, self.block_size, self.num_kv_heads, self.head_size
+    )
+    # kv_cache_shape = (2, 512, 16, 1, 576)
     kv_cache: List[MsKVCache] = []
 
     self.dtype = get_valid_dtype(self.dtype)
@@ -71,3 +73,54 @@ def ms_swap_out(self, src_to_dst: ms.Tensor) -> None:
         self.attn_backend.swap_blocks(
             self.gpu_cache[i], self.cpu_cache[i], src_to_dst, True
         )
+
+def cache_engine_init(
+    self,
+    cache_config: CacheConfig,
+    model_config: ModelConfig,
+    parallel_config: ParallelConfig,
+    device_config: DeviceConfig,
+    ) -> None:
+    self.cache_config = cache_config
+    self.model_config = model_config
+    self.parallel_config = parallel_config
+    self.device_config = device_config
+
+    self.head_size = model_config.get_head_size()
+    # Models like Jamba have mixed-type layers, e.g. Mamba.
+    self.num_attention_layers = model_config.get_num_layers_by_block_type(
+        parallel_config, LayerBlockType.attention)
+    self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+
+    self.block_size = cache_config.block_size
+    self.num_gpu_blocks = cache_config.num_gpu_blocks
+    if self.num_gpu_blocks:
+        self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
+    self.num_cpu_blocks = cache_config.num_cpu_blocks
+    if self.num_cpu_blocks:
+        self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
+
+    if cache_config.cache_dtype == "auto":
+        self.dtype = model_config.dtype
+    else:
+        self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
+    # Determine whether the model uses MLA.
+    if (os.getenv("VLLM_MODEL_BACKEND") == "MindFormer"
+            and model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLM"):
+        is_mla = True
+        logger.info("DeepseekV3 detected, using the MLA attention backend.")
+    else:
+        is_mla = False
+    # Get attention backend.
+    self.attn_backend = get_attn_backend(self.head_size,
+                                         model_config.dtype,
+                                         cache_config.cache_dtype,
+                                         self.block_size,
+                                         model_config.is_attention_free,
+                                         use_mla=is_mla)
+
+    # Initialize the cache.
+    self.gpu_cache = self._allocate_kv_cache(
+        self.num_gpu_blocks, self.device_config.device_type)
+    self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
\ No newline at end of file
-- 
Gitee
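
A minimal usage sketch (illustrative only, not part of the patch), assuming the vllm_mindspore package built from this change is importable. It shows the KV cache layout the new MLABackend reports, which replaces the previously hard-coded (2, 512, 16, 1, 576) shape in ms_allocate_kv_cache:

    from vllm_mindspore.attention.backends.ms_attn import MLABackend

    # MLA keeps a single combined latent cache, so the leading dimension is 1
    # instead of the usual 2 (separate key and value caches).
    shape = MLABackend.get_kv_cache_shape(num_blocks=512, block_size=16,
                                          num_kv_heads=1, head_size=576)
    print(shape)                                  # (1, 512, 16, 1, 576)
    print(MLABackend.get_supported_head_sizes())  # [576]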