diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index f0e5621e54f443bd9594d83d390b399810446063..50026026dda01d4b4cb310bd63d82d9d91c83c82 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -20,9 +20,10 @@ from collections.abc import Iterable
 from typing import Any, Optional, Union, cast
 
 import mindspore as ms
+import ms_custom_ops
 import numpy as np
 import vllm.envs as envs
-from mindspore import Tensor, mutable, nn, ops
+from mindspore import Tensor, mutable, nn
 from mindspore.common import dtype as mstype
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import VllmConfig, get_current_vllm_config
@@ -93,17 +94,16 @@ class MLAAttentionWrapper(AttentionWrapper):
             # format_cast ops may not recycle device memory
             k_shape = [1, *(self.kv_shape[1:-2]), kv_lora_rank]
             r_shape = [1, *(self.kv_shape[1:-2]), qk_rope_head_dim]
-            self.kv_cache = [
-                (ops.auto_generate.format_cast(
-                    ms.mint.zeros(k_shape, dtype=kv_cache_dtype), 29),
-                 ops.auto_generate.format_cast(
-                     ms.mint.zeros(r_shape,
-                                   dtype=vllm_config.model_config.dtype),
-                     29))
-                for _ in range(
-                    vllm_config.parallel_config.pipeline_parallel_size)
-            ]
-
+            # Currently, transdata has a bug and ms.jit must be added.
+            # Later, ms.jit will be removed.
+            self.kv_cache = [(ms.jit(ms_custom_ops.trans_data)(
+                ms.mint.zeros(k_shape, dtype=kv_cache_dtype),
+                transdata_type=1), ms.jit(ms_custom_ops.trans_data)(
+                    ms.mint.zeros(r_shape,
+                                  dtype=vllm_config.model_config.dtype),
+                    transdata_type=1)) for _ in range(
+                        vllm_config.parallel_config.pipeline_parallel_size)
+            ]
         else:
             k_shape = [*(self.kv_shape[0:-1]), kv_lora_rank]
             r_shape = [*(self.kv_shape[0:-1]), qk_rope_head_dim]
diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index 54fb277cb9cd96e31c2fa395f6e4b276014ee595..cbbcf6e3413b4f502889d6c3c2b356e5f6510145 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -22,6 +22,7 @@ import traceback
 from typing import Any, Optional
 
 import mindspore as ms
+import ms_custom_ops
 import numpy as np
 import torch
 from mindspore import Generator as msGenerator
@@ -477,11 +478,13 @@ def _reshape_kv_cache_tensors(
                     kv_cache_shape[1:]).permute(*inv_order[1:])
                 if fa3_quant:
                     # for fa3_quant, kvcache need be nz format due to ops
+                    # Currently, transdata has a bug and ms.jit must be
+                    # added. Later, ms.jit will be removed.
                     num_blocks, block_size, _, _ = cache_block.shape
                     cache_block = ops.reshape(cache_block,
                                               (num_blocks, block_size, -1))
-                    cache_block_nz = ops.auto_generate.format_cast(
-                        cache_block, 29)
+                    cache_block_nz = ms.jit(ms_custom_ops.trans_data)\
+                        (cache_block, transdata_type=1)
                     kv_cache_layer.append(cache_block_nz)
                 else:
                     kv_cache_layer.append(cache_block)
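
Reviewer note (not part of the patch): the sketch below shows the API swap in isolation, replacing ops.auto_generate.format_cast(x, 29) with ms.jit(ms_custom_ops.trans_data)(x, transdata_type=1) to obtain an NZ-format cache tensor. It assumes an Ascend environment with ms_custom_ops installed; the cache shape and dtype are made up for illustration.

```python
# Minimal sketch of the format conversion swapped in by this patch.
# Assumptions: Ascend NPU target, ms_custom_ops installed; the shape and
# dtype below are illustrative only.
import mindspore as ms
import ms_custom_ops

cache = ms.mint.zeros((16, 128, 512), dtype=ms.float16)  # hypothetical cache

# Old path (removed by this patch): cast the ND tensor to format id 29,
# the Ascend FRACTAL_NZ layout:
#     cache_nz = ops.auto_generate.format_cast(cache, 29)

# New path: trans_data with transdata_type=1 converts the ND tensor to the
# NZ layout. The ms.jit wrapper is a temporary workaround for a transdata
# bug, per the inline comments in the patch.
cache_nz = ms.jit(ms_custom_ops.trans_data)(cache, transdata_type=1)
```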