diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py
index 57cf46b53221acc402bc12022c21d023a5651aa5..41d7062a83d686230b09d98756f8ed76bf929662 100644
--- a/vllm_mindspore/model_executor/layers/rotary_embedding.py
+++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py
@@ -37,11 +37,6 @@ from transformers import PretrainedConfig
 from vllm.config import get_current_vllm_config
 from vllm_mindspore.model_executor.utils import get_model_context
 
-try:
-    import ms_custom_ops
-    ms_custom_ops_avail = True
-except ImportError:
-    ms_custom_ops_avail = False
 
 def _apply_rotary_emb(
     x: Tensor,
@@ -283,6 +278,14 @@ class MRotaryEmbedding(RotaryEmbedding):
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
+        try:
+            import ms_custom_ops
+            self.ms_custom_ops_avail = True
+            self.apply_rotary_pos_emb_v3 = ms_custom_ops.apply_rotary_pos_emb_v3
+        except ImportError:
+            self.ms_custom_ops_avail = False
+            self.apply_rotary_pos_emb_v3 = None
+
     def construct(
         self,
         positions: mindspore.Tensor,
@@ -322,7 +325,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         query = query.view(num_tokens, -1, self.head_size)
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
-        if ms_custom_ops_avail is False:
+        if self.ms_custom_ops_avail is False:
             query_rot = query[..., :self.rotary_dim]
             query_pass = query[..., self.rotary_dim:]
             query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
@@ -333,7 +336,7 @@ class MRotaryEmbedding(RotaryEmbedding):
             key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
             key = mint.cat((key_rot, key_pass), dim=-1).view(key_shape)
         else:
-            query, key = ms_custom_ops.apply_rotary_pos_emb_v3(query, key, cos, sin, "BSH", "interleave")
+            query, key = self.apply_rotary_pos_emb_v3(query, key, cos, sin, "BSH", "interleave")
         query = query.view(query_shape)
         key = key.view(key_shape)
         return query, key
diff --git a/vllm_mindspore/model_executor/models/glm4_1v.py b/vllm_mindspore/model_executor/models/glm4_1v.py
index d1eb246aa6b12246ee44ea5f5b3e2da5af304d7d..7cf42f63b11b5d57bc6bf62f535cc1d0f804fcc9 100644
--- a/vllm_mindspore/model_executor/models/glm4_1v.py
+++ b/vllm_mindspore/model_executor/models/glm4_1v.py
@@ -89,7 +89,6 @@ from vllm_mindspore.model_executor.models.qwen2_5_vl import _qwen2vl_field_confi
 import mindspore as ms
 from mindspore import mint, ops
 from mindspore.ops.operations.nn_ops import FlashAttentionScore
-import ms_custom_ops
 from vllm_mindspore.utils import is_310p
 from mindspore.common.api import _pynative_executor
 _pynative_executor.set_enable_grad(False)
@@ -250,6 +249,9 @@ class Glm4vVisionAttention(nn.Cell):
             input_layout="TH")
         self.dtype = get_current_vllm_config().model_config.dtype
 
+        import ms_custom_ops
+        self.apply_rotary_pos_emb_ext = ms_custom_ops.apply_rotary_pos_emb_ext
+
     def construct(
         self,
         x: Tensor,
@@ -272,9 +274,9 @@ class Glm4vVisionAttention(nn.Cell):
 
         cos, sin = position_embeddings
         origin_dtype = q.dtype
-        q, k = ms_custom_ops.apply_rotary_pos_emb_ext(q.astype(ms.float32),
-                                                      k.astype(ms.float32),
-                                                      cos, sin, "BSND", "half")
+        q, k = self.apply_rotary_pos_emb_ext(q.astype(ms.float32),
+                                             k.astype(ms.float32),
+                                             cos, sin, "BSND", "half")
 
         # q/k reshape to TH
         q = q.astype(origin_dtype)
@@ -435,6 +437,9 @@ class Glm4vVisionEmbeddings(nn.Cell):
         self.position_embedding = mint.nn.Embedding(self.num_positions,
                                                     self.embed_dim,
                                                     dtype=self.dtype)  # (576, 1536)
 
+        import ms_custom_ops
+        self.grid_sample = ms_custom_ops.grid_sample
+
     def construct(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> Tensor:
         pos_embed_weight = self.position_embedding.weight
@@ -480,7 +485,7 @@ class Glm4vVisionEmbeddings(nn.Cell):
                 dim=-1).unsqueeze(0).unsqueeze(2))
 
         # Perform bicubic interpolation
-        interpolated_embed_fp32 = ms_custom_ops.grid_sample(
+        interpolated_embed_fp32 = self.grid_sample(
             pos_embed_2d,
             grid,
             # mode="bicubic",
diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py
index 8e39c3d1560d7de3b653ac22eefd5fdc51cbd3ea..fdff477a3b8b6a9661703ad062f1131de4264e5d 100644
--- a/vllm_mindspore/platforms/ascend.py
+++ b/vllm_mindspore/platforms/ascend.py
@@ -105,6 +105,19 @@ class AscendPlatform(Platform):
         model_config = vllm_config.model_config
         model_config.disable_cascade_attn = True
 
+        # The mm preprocessor cache between P0 and P1 is only effective in
+        # one-to-one setups; data parallelism is one-to-many, so disable it.
+        if (not model_config.disable_mm_preprocessor_cache
+                and parallel_config.data_parallel_size > 1):
+            if model_config.multimodal_config is None:
+                raise RuntimeError(
+                    "For disable_mm_preprocessor_cache, multimodal_config "
+                    "should not be None!")
+            model_config.multimodal_config.disable_mm_preprocessor_cache = True
+            logger.info(
+                "Disable mm preprocessor cache for data parallel size %d.",
+                parallel_config.data_parallel_size)
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla):
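
The change in rotary_embedding.py boils down to one pattern: resolve the optional ms_custom_ops kernel once per instance in __init__ instead of at module import time, and fall back to the generic rotary path when the extension is missing. Below is a minimal, self-contained sketch of that pattern; it only assumes that ms_custom_ops exposes apply_rotary_pos_emb_v3 as used in the patch, and the class and method names (RopeWithOptionalKernel, apply, _reference_rope) are illustrative rather than part of this change.

    # Sketch of the optional-kernel pattern (illustrative names, plain Python).
    class RopeWithOptionalKernel:

        def __init__(self):
            try:
                # Deferred import: only needed when the class is instantiated,
                # so importing the module no longer requires ms_custom_ops.
                import ms_custom_ops
                self.ms_custom_ops_avail = True
                self.apply_rotary_pos_emb_v3 = ms_custom_ops.apply_rotary_pos_emb_v3
            except ImportError:
                self.ms_custom_ops_avail = False
                self.apply_rotary_pos_emb_v3 = None

        def apply(self, query, key, cos, sin):
            if self.ms_custom_ops_avail is False:
                # Generic path, standing in for the existing _apply_rotary_emb()
                # slow path in rotary_embedding.py.
                return self._reference_rope(query, key, cos, sin)
            # Fused custom kernel, called with the same layout arguments
            # ("BSH", "interleave") as in the patch.
            return self.apply_rotary_pos_emb_v3(query, key, cos, sin, "BSH",
                                                "interleave")

        def _reference_rope(self, query, key, cos, sin):
            # Placeholder for the pure-MindSpore rotation; omitted here.
            return query, key

The same idea is applied in glm4_1v.py: apply_rotary_pos_emb_ext and grid_sample are bound to the instance in __init__, so the model file imports cleanly on environments without ms_custom_ops installed.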