diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index e5263be1a41ac7e669518d568f1f9d8bbb9d81a1..95fc1f50cc94960e2e7cd764298dd040731e7a61 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -22,9 +22,11 @@ import warnings import msadapter # noqa: F401 from vllm_mindspore.msadapter_patch import patch_msadapter from vllm_mindspore.ray_patch import patch_ray +from vllm_mindspore.transformers_patch import patch_transformers patch_msadapter() patch_ray() +patch_transformers() if "vllm" in sys.modules: # Check models variable in sub process, cannot raise here. @@ -233,18 +235,15 @@ vllm.config.ParallelConfig.has_unfinished_dp = has_unfinished_dp from .utils import update_modules ######### for multi-model -from vllm_mindspore.multimodal.inputs import (as_kwargs, batched_reduce_data, - flat_build_elems, - flat_reduce_data, from_items, - MultiModalFieldElem, _try_stack) +from vllm_mindspore.multimodal.inputs import (as_kwargs, flat_build_elems, + from_items, MultiModalFieldElem, + _try_stack) from vllm.multimodal.inputs import MultiModalBatchedField from vllm.multimodal.inputs import MultiModalFlatField from vllm.multimodal.inputs import MultiModalKwargs -MultiModalBatchedField._reduce_data = batched_reduce_data MultiModalFlatField.build_elems = flat_build_elems -MultiModalFlatField._reduce_data = flat_reduce_data MultiModalKwargs.as_kwargs = as_kwargs MultiModalKwargs.from_items = from_items MultiModalKwargs._try_stack = _try_stack @@ -465,4 +464,8 @@ from vllm.v1.engine.processor import Processor Processor._validate_sampling_params = v1_process_validate_sampling_params Processor._validate_structured_output = v1_process_validate_structured_output +from vllm_mindspore.multimodal.processing import call_hf_processor +from vllm.multimodal.processing import InputProcessingContext +InputProcessingContext.call_hf_processor = call_hf_processor + check_ready() diff --git a/vllm_mindspore/model_executor/layers/activation.py b/vllm_mindspore/model_executor/layers/activation.py index 41fe40dea7eab94edcd14d5e33c7fab06572c3af..63bd87ef8ddd9939dda445a1644bcb5ea0f6487a 100644 --- a/vllm_mindspore/model_executor/layers/activation.py +++ b/vllm_mindspore/model_executor/layers/activation.py @@ -19,7 +19,7 @@ # limitations under the License. from mindspore import mint, nn, ops - +from vllm_mindspore.utils import LazyDict class SiluAndMul(nn.Cell): """An activation function for SwiGLU. 
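As a reference for the `SiluAndMul` cell shown above: it implements the usual SwiGLU gating, splitting the last dimension in half and scaling the second half by `silu` of the first. A minimal NumPy sketch of that math (helper names are illustrative and independent of MindSpore):

```python
import numpy as np

def silu(x: np.ndarray) -> np.ndarray:
    # silu(x) = x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def silu_and_mul(x: np.ndarray) -> np.ndarray:
    # Split the last dimension into [gate, hidden] and gate the hidden half,
    # mirroring mint.mul(hidden, silu(gate)) in the MindSpore cell.
    d = x.shape[-1] // 2
    gate, hidden = x[..., :d], x[..., d:]
    return hidden * silu(gate)

x = np.random.randn(4, 8).astype(np.float32)  # last dim must be even
print(silu_and_mul(x).shape)                  # -> (4, 4)
```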
@@ -39,3 +39,9 @@ class SiluAndMul(nn.Cell): d = x.shape[-1] // 2 gate, hidden = self.split(x, [d, d], dim=-1) return mint.mul(hidden, mint.nn.functional.silu(gate)) + + +_ACTIVATION_REGISTRY = LazyDict({ + "gelu_pytorch_tanh": + lambda: mint.nn.GELU(approximate="tanh"), +}) diff --git a/vllm_mindspore/model_executor/layers/layernorm.py b/vllm_mindspore/model_executor/layers/layernorm.py index 569d325d60ca097082efaed6fd78185ae0937749..5f1e351afe007092a10be08e4a534197344acf41 100644 --- a/vllm_mindspore/model_executor/layers/layernorm.py +++ b/vllm_mindspore/model_executor/layers/layernorm.py @@ -48,7 +48,7 @@ class RMSNorm(nn.Cell): requires_grad=False) self.rms_norm = ops.RmsNorm(eps) self.eps = eps - self.add_rms_norm = AddRmsNorm() + self.add = ops.Add() def construct( self, @@ -56,8 +56,8 @@ class RMSNorm(nn.Cell): residual: Optional[Tensor] = None ) -> Union[Tensor, tuple[Tensor, Tensor]]: if residual is not None: - output, _, residual = self.add_rms_norm(x, residual, self.weight, - self.eps) + residual = self.add(x, residual) + output = self.rms_norm(residual, self.weight)[0] return output, residual output = self.rms_norm(x, self.weight)[0] return output diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index 2d52f7096e63643058b75285aa6fc3f6fd367b2e..2a8364a3a3e37e24aea1991872bccf47e5807d6e 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -24,11 +24,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import math from typing import Any, Optional, Union -import mindspore import numpy as np +import mindspore as ms from mindspore import Tensor, mint, nn, ops from mindspore.common import dtype as mstype from mindspore.ops.auto_generate.gen_ops_prim import SliceExt @@ -36,7 +37,17 @@ from transformers import PretrainedConfig from vllm.config import get_current_vllm_config from vllm_mindspore.model_executor.utils import get_model_context - +from vllm_mindspore.model_executor.models.vision import ( + get_llm_pos_ids_for_vision +) + +def _get_feat_extract_output_lengths(input_lengths: ms.Tensor): + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ( + ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + ) + return feat_lengths, output_lengths def _apply_rotary_emb( x: Tensor, @@ -268,8 +279,9 @@ class MRotaryEmbedding(RotaryEmbedding): max_position_embeddings: int, base: float, is_neox_style: bool, - dtype: mindspore.Type, + dtype: ms.Type, mrope_section: Optional[list[int]] = None, + mrope_interleaved: bool = False, ) -> None: # In Qwen2.5-VL, the maximum index value is related to the duration of # the input video. 
We enlarge max_position_embeddings to 4 times to get @@ -282,13 +294,59 @@ class MRotaryEmbedding(RotaryEmbedding): if self.mrope_section: assert sum(self.mrope_section) == rotary_dim // 2 + self.mrope_interleaved = mrope_interleaved + if self.mrope_interleaved: + assert len(self.mrope_section) == 3 + mrope_section_np = np.array(self.mrope_section, dtype=np.int64) + sec_total = mrope_section_np.sum() + h_sec = np.array(list(range(1, self.mrope_section[1] * 3, 3))) + sec_total + w_sec = np.array(list(range(2, self.mrope_section[2] * 3, 3))) + 2 * sec_total + select_index = np.arange(sec_total, dtype=np.int64) + select_index[1 : mrope_section[1] * 3 : 3] = h_sec + select_index[2 : mrope_section[2] * 3 : 3] = w_sec + self.rope_select_index = ms.from_numpy(select_index) + else: + assert len(self.mrope_section) == 3 + mrope_section_np = np.array(self.mrope_section, dtype=np.int64) + sec_total = mrope_section_np.sum() + sec_cu = mrope_section_np.cumsum() + h_sec = np.arange(sec_cu[0], sec_cu[1]) + sec_total + w_sec = np.arange(sec_cu[1], sec_cu[2]) + 2 * sec_total + select_index = np.arange(sec_total, dtype=np.int64) + select_index[sec_cu[0] : sec_cu[1]] = h_sec + select_index[sec_cu[1] : sec_cu[2]] = w_sec + self.rope_select_index = ms.from_numpy(select_index) + + if self.is_neox_style and self.rotary_dim == self.head_size: + self.rotary_embedding_op = ops.ApplyRotaryPosEmb(2) + + def apply_interleaved_rope(self, x: Tensor, mrope_section: list[int]) -> Tensor: + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + """ + x = ops.transpose(x, (1, 0, 2)) + x = mint.flatten(x, start_dim=1) + x_t = mint.index_select(x, -1, self.rope_select_index) + return x_t + + def apply_no_interleaved_rope(self, x: Tensor, mrope_section: list[int]) -> Tensor: + """Apply non-interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + non-interleaved [TTTHHHWWW]. 
+ """ + x = ops.transpose(x, (1, 0, 2)) + x = mint.flatten(x, start_dim=1) + x_t = mint.index_select(x, -1, self.rope_select_index) + return x_t + def construct( self, - positions: mindspore.Tensor, - query: mindspore.Tensor, - key: mindspore.Tensor, + positions: ms.Tensor, + query: ms.Tensor, + key: ms.Tensor, batch_valid_length: Tensor = None, - ) -> tuple[mindspore.Tensor, mindspore.Tensor]: + ) -> tuple[ms.Tensor, ms.Tensor]: """ Args: positions: @@ -308,14 +366,19 @@ class MRotaryEmbedding(RotaryEmbedding): cos_sin = self.cos_sin_cache[positions] cos, sin = ops.chunk(cos_sin, 2, axis=-1) if positions.ndim == 2: - cos_l = mint.split(cos, self.mrope_section, dim=-1) - sin_l = mint.split(sin, self.mrope_section, dim=-1) - cos, sin = (), () - for i in range(len(self.mrope_section)): # type: ignore[arg-type] - cos += (cos_l[i][i], ) - sin += (sin_l[i][i], ) - cos = mint.cat(cos, dim=-1) - sin = mint.cat(sin, dim=-1) + if self.mrope_interleaved: + cos = self.apply_interleaved_rope(cos, self.mrope_section) + sin = self.apply_interleaved_rope(sin, self.mrope_section) + else: + cos = self.apply_no_interleaved_rope(cos, self.mrope_section) + sin = self.apply_no_interleaved_rope(sin, self.mrope_section) + + if self.is_neox_style and self.rotary_dim == self.head_size: + freqs_cos = mint.cat((cos, cos), dim=-1) + freqs_sin = mint.cat((sin, sin), dim=-1) + query, key = self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, + batch_valid_length) + return query, key query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) @@ -336,11 +399,13 @@ class MRotaryEmbedding(RotaryEmbedding): def get_input_positions( input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], mindspore.Tensor], - video_grid_thw: Union[list[list[int]], mindspore.Tensor], + image_grid_thw: Union[list[list[int]], ms.Tensor], + video_grid_thw: Union[list[list[int]], ms.Tensor], second_per_grid_ts: Optional[list[float]] = None, context_len: int = 0, seq_len: Optional[int] = None, + audio_feature_lengths: Optional[ms.Tensor] = None, + use_audio_in_video: bool = False, ) -> tuple[list[list[int]], int]: """Get mrope input positions and delta value.""" @@ -353,6 +418,8 @@ class MRotaryEmbedding(RotaryEmbedding): second_per_grid_ts=second_per_grid_ts, context_len=context_len, seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video ) return llm_positions.tolist(), mrope_position_delta @@ -362,13 +429,47 @@ class MRotaryEmbedding(RotaryEmbedding): cls, input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], mindspore.Tensor], - video_grid_thw: Union[list[list[int]], mindspore.Tensor], - second_per_grid_ts=None, + image_grid_thw: Union[list[list[int]], ms.Tensor], + video_grid_thw: Union[list[list[int]], ms.Tensor], + second_per_grid_ts: Optional[list[float]] = None, context_len: int = 0, seq_len: Optional[int] = None, - ) -> tuple[mindspore.Tensor, int]: + audio_feature_lengths: Optional[ms.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[ms.Tensor, int]: """Get mrope input positions and delta value.""" + from vllm.transformers_utils.config import thinker_uses_mrope + if thinker_uses_mrope(hf_config): + return cls._qwen3_omni_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + 
audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + return cls._qwen3_vl_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + ) + elif hf_config.model_type in ["glm4v", "glm4v_moe"]: + return cls._glm4v_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) return cls._vl_get_input_positions_tensor( input_tokens=input_tokens, hf_config=hf_config, @@ -379,6 +480,544 @@ class MRotaryEmbedding(RotaryEmbedding): seq_len=seq_len, ) + @classmethod + def _qwen3_vl_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | Tensor, + video_grid_thw: list[list[int]] | Tensor, + context_len: int = 0, + seq_len: int | None = None, + second_per_grid_ts: list[float] | None = None, + ) -> tuple[Tensor, int]: + """Get mrope input positions and delta value.""" + + video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + input_tokens_tensor = ms.tensor(input_tokens) + vision_start_indices = ms.ops.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + mint.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + mint.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + mint.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + mint.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + mint.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + 
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + mint.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = mint.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions, mrope_position_delta + + @classmethod + def _glm4v_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], Tensor], + video_grid_thw: Union[list[list[int]], Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[Tensor, int]: + """Get mrope input positions and delta value for GLM4V (NumPy only).""" + + image_token_id = hf_config.image_token_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + llm_pos_ids_list: list[np.ndarray] = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, Tensor): + image_grid_thw = image_grid_thw.asnumpy().tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = int(llm_pos_ids_list[-1].max() + 1) if len( + llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t = int(t) + llm_grid_h = int(h // spatial_merge_size) + llm_grid_w = int(w // spatial_merge_size) + + t_indices, h_indices, w_indices = np.meshgrid( + np.arange(llm_grid_t, dtype=np.int64), + np.arange(llm_grid_h, dtype=np.int64), + np.arange(llm_grid_w, dtype=np.int64), + indexing='ij' + ) + + stacked = np.stack([ + t_indices.ravel(), + h_indices.ravel(), + w_indices.ravel() + ], axis=0) + st_idx + + llm_pos_ids_list.append(stacked) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t = int(t) + llm_grid_h = int(h // spatial_merge_size) + llm_grid_w = int(w // spatial_merge_size) + + for t_idx in range(llm_grid_t): + t_indices, h_indices, w_indices = np.meshgrid( + np.arange(t_idx, dtype=np.int64), + np.arange(llm_grid_h, dtype=np.int64), + np.arange(llm_grid_w, dtype=np.int64), + indexing='ij' + ) + + stacked = np.stack([ + t_indices.ravel(), + h_indices.ravel(), + w_indices.ravel() + ], axis=0) + st_idx + + llm_pos_ids_list.append(stacked) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = int(end_idx - start_idx) + base = 
np.arange(text_len, dtype=np.int64) + stacked = np.tile(base, (3, 1)) + st_idx + llm_pos_ids_list.append(stacked) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + base = np.arange(text_len, dtype=np.int64) + llm_pos_ids_list.append(np.tile(base, (3, 1))) + + llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = int(llm_positions.max() + 1 - len(input_tokens)) + return Tensor(llm_positions), mrope_position_delta + + @classmethod + def _qwen3_omni_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | Tensor | None, + video_grid_thw: list[list[int]] | Tensor | None, + second_per_grid_ts: list[float] | None = None, + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[Tensor, int]: + config = hf_config.thinker_config + if isinstance(image_grid_thw, list): + image_grid_thw = Tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = Tensor(video_grid_thw) + input_ids = Tensor(input_tokens) + if input_ids is None or input_ids.ndim != 1: + raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") + + seq_len = input_ids.shape[0] + if audio_feature_lengths is not None and not isinstance( + audio_feature_lengths, Tensor + ): + audio_feature_lengths = Tensor( + audio_feature_lengths, dtype=ms.int64 + ) + if second_per_grid_ts is None: + if video_grid_thw is not None and video_grid_thw.numel() > 0: + second_per_grids = mint.ones( + video_grid_thw.shape[0], dtype=ms.float32 + ) + else: + second_per_grids = Tensor([], dtype=ms.float32) + else: + second_per_grids = Tensor(second_per_grid_ts, dtype=ms.float32) + + spatial_merge_size = config.vision_config.spatial_merge_size + image_token_id = config.image_token_id + video_token_id = config.video_token_id + audio_token_id = config.audio_token_id + vision_start_token_id = config.vision_start_token_id + audio_start_token_id = config.audio_start_token_id + position_id_per_seconds = config.position_id_per_seconds + + vision_start_indices = ops.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + if vision_start_indices.numel() > 0: + vision_tokens = input_ids[vision_start_indices + 1] + else: + vision_tokens = mint.empty((0,), dtype=input_ids.dtype) + audio_nums = mint.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + + llm_pos_ids_list: list[Tensor] = [] + st = 0 + image_idx = 0 + video_idx = 0 + audio_idx = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501 + multimodal_nums = ( + image_nums + audio_nums + if use_audio_in_video + else image_nums + video_nums + audio_nums + ) # noqa: E501 + + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + if (image_token_id in input_tokens or video_token_id in input_tokens) and ( + remain_videos > 0 or remain_images > 0 + ): + ed_vision_start = input_tokens.index(vision_start_token_id, st) + else: + ed_vision_start = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio_start = input_tokens.index(audio_start_token_id, st) + else: + ed_audio_start = len(input_tokens) + 1 + 
min_ed = min(ed_vision_start, ed_audio_start) + + if min_ed == ed_audio_start: + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + mint.arange(text_len, dtype=ms.int64) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append( + mint.arange(bos_len, dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + _, audio_len = _get_feat_extract_output_lengths( + audio_feature_lengths[audio_idx] + ) + llm_pos_ids = ( + mint.arange(audio_len, dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append( + mint.arange(eos_len, dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + elif ( + min_ed == ed_vision_start + and input_ids[ed_vision_start + 1] == image_token_id + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + mint.arange(text_len, dtype=ms.int64) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append( + mint.arange(bos_len, dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = mint.arange(grid_t.item()) * position_id_per_seconds + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append( + mint.arange(eos_len, dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + elif ( + min_ed == ed_vision_start + and input_ids[ed_vision_start + 1] == video_token_id + and not use_audio_in_video + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + mint.arange(text_len, dtype=ms.int64) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append( + mint.arange(bos_len, dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + mint.arange(grid_t.item()) + * float(second_per_grids[video_idx].item()) + * position_id_per_seconds + ) + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append( + mint.arange(eos_len, 
dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + elif ( + min_ed == ed_vision_start + and ed_vision_start + 1 == ed_audio_start + and use_audio_in_video + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + mint.arange(text_len, dtype=ms.int64) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + bos_block = ( + mint.arange(bos_len, dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + llm_pos_ids_list.append(bos_block) + llm_pos_ids_list.append(bos_block) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + _, audio_len = _get_feat_extract_output_lengths( + audio_feature_lengths[audio_idx] + ) + audio_llm_pos_ids = ( + mint.arange(audio_len, dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + mint.arange(grid_t.item()) + * float(second_per_grids[video_idx].item()) + * position_id_per_seconds + ) + video_llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_data_index, audio_data_index = 0, 0 + while ( + video_data_index < video_llm_pos_ids.shape[-1] + and audio_data_index < audio_llm_pos_ids.shape[-1] + ): + if ( + video_llm_pos_ids[0][video_data_index] + <= audio_llm_pos_ids[0][audio_data_index] + ): + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, video_data_index : video_data_index + 1 + ] + ) + video_data_index += 1 + else: + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, audio_data_index : audio_data_index + 1 + ] + ) + audio_data_index += 1 + if video_data_index < video_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, video_data_index : video_llm_pos_ids.shape[-1] + ] + ) + if audio_data_index < audio_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, audio_data_index : audio_llm_pos_ids.shape[-1] + ] + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + eos_block = ( + mint.arange(eos_len, dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + llm_pos_ids_list.append(eos_block) + llm_pos_ids_list.append(eos_block) + st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501 + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + mint.arange(text_len.item(), dtype=ms.int64).view(1, -1).expand(3, -1) + + st_idx + ) + + llm_positions = mint.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + if llm_positions.shape[1] != seq_len: + raise RuntimeError("Position ids length mismatch with input ids length") + + mrope_position_delta = llm_positions.max() + 1 - seq_len + return llm_positions, mrope_position_delta.item() + @classmethod def _vl_get_input_positions_tensor( cls, @@ -398,12 +1037,12 @@ class MRotaryEmbedding(RotaryEmbedding): tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) - if isinstance(image_grid_thw, mindspore.Tensor): + if isinstance(image_grid_thw, ms.Tensor): image_grid_thw = image_grid_thw.tolist() 
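The `use_audio_in_video` branch of `_qwen3_omni_get_input_positions_tensor` above interleaves the audio and video position-id blocks by temporal order, always emitting the column whose t value is smaller. A small NumPy sketch of that two-pointer merge, with toy (3, N) blocks standing in for the real grids:

```python
import numpy as np

def merge_by_t(video_pos: np.ndarray, audio_pos: np.ndarray) -> np.ndarray:
    # Both inputs are (3, N) blocks of (t, h, w) position ids. Columns are
    # emitted in non-decreasing order of the t row, mirroring the two-pointer
    # loop in the hunk above; leftovers of either stream are appended at the end.
    out, vi, ai = [], 0, 0
    while vi < video_pos.shape[-1] and ai < audio_pos.shape[-1]:
        if video_pos[0, vi] <= audio_pos[0, ai]:
            out.append(video_pos[:, vi:vi + 1])
            vi += 1
        else:
            out.append(audio_pos[:, ai:ai + 1])
            ai += 1
    out.append(video_pos[:, vi:])  # remaining video columns, if any
    out.append(audio_pos[:, ai:])  # remaining audio columns, if any
    return np.concatenate(out, axis=1)

video = np.array([[0, 0, 25, 25], [0, 1, 0, 1], [0, 1, 0, 1]])
audio = np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]])
print(merge_by_t(video, audio))  # t row: 0 0 0 1 2 25 25
```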
- if isinstance(video_grid_thw, mindspore.Tensor): + if isinstance(video_grid_thw, ms.Tensor): video_grid_thw = video_grid_thw.tolist() - input_tokens_tensor = mindspore.Tensor(input_tokens) + input_tokens_tensor = ms.Tensor(input_tokens) vision_start_indices = ops.argwhere( input_tokens_tensor == vision_start_token_id).squeeze(1) vision_tokens = input_tokens_tensor[vision_start_indices + 1] @@ -503,15 +1142,14 @@ class MRotaryEmbedding(RotaryEmbedding): ] @staticmethod - def get_next_input_positions_tensor( - mrope_position_delta: int, - context_len: int, - seq_len: int, - ) -> mindspore.Tensor: - return mint.arange( + def get_next_input_positions_tensor(out: ms.Tensor, out_offset: int, + mrope_position_delta: int, + context_len: int, num_new_tokens: int): + values = mint.arange( int(mrope_position_delta + context_len), - int(mrope_position_delta + seq_len), + int(mrope_position_delta + context_len + num_new_tokens), ).broadcast_to((3, -1)) + out[:, out_offset:out_offset + num_new_tokens] = values class InferMRotaryEmbedding(InferRotaryEmbedding): @@ -530,7 +1168,7 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): max_position_embeddings: int, base: float, is_neox_style: bool, - dtype: mindspore.Type, + dtype: ms.Type, mrope_section: Optional[list[int]] = None, ) -> None: # In Qwen2.5-VL, the maximum index value is related to the duration of @@ -554,11 +1192,11 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): def construct( # type: ignore[override] self, - positions: mindspore.Tensor, - query: mindspore.Tensor, - key: mindspore.Tensor, + positions: ms.Tensor, + query: ms.Tensor, + key: ms.Tensor, batch_valid_length: Tensor = None, - ) -> tuple[mindspore.Tensor, mindspore.Tensor]: + ) -> tuple[ms.Tensor, ms.Tensor]: """ Args: positions: @@ -801,26 +1439,16 @@ def get_rope( original_max_position) elif scaling_type == "default": if "mrope_section" in rope_scaling: - if is_neox_style: - rotary_emb = InferMRotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - dtype, - mrope_section=rope_scaling["mrope_section"], - ) - else: - rotary_emb = MRotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - dtype, - mrope_section=rope_scaling["mrope_section"], - ) + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), + ) else: raise NotImplementedError elif scaling_type == "yarn": diff --git a/vllm_mindspore/model_executor/models/interfaces.py b/vllm_mindspore/model_executor/models/interfaces.py index cee654577463c3c51b94cd0130df84e947e02afc..8c23e2b8355d39f39b610b0fdddd0a9adc6979c9 100644 --- a/vllm_mindspore/model_executor/models/interfaces.py +++ b/vllm_mindspore/model_executor/models/interfaces.py @@ -44,6 +44,7 @@ class SupportsMultiModal(Protocol): MRO of your model class. 
""" + merge_by_field_config: ClassVar[bool] = False def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: """ diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 04d32360b23812f70a23250741169ed2edfae6ff..00a00ea7f664a326a4e545e6ac9f717be3fa91f1 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -621,7 +621,7 @@ class NativeModel(MsModelBase): model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, intermediate_tensors, inputs_embeds) - + model_inputs.update(kwargs) # for dummy_attention_metadata if is_prefill and not self.has_prefill_warmup: self.has_prefill_warmup = True diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py index 12589a9daa3b205955b689da11e749cb343a5f2a..047c3e58d983fcbec6bd7a45e9007a6b96fd9eb0 100644 --- a/vllm_mindspore/model_executor/models/qwen2_5_vl.py +++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py @@ -53,7 +53,6 @@ from vllm.distributed import get_pp_group, \ get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -181,7 +180,7 @@ Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs, Qwen2_5_VLVideoEmbeddingInputs] # For profile run -_MAX_FRAMES_PER_VIDEO = 16 +_MAX_FRAMES_PER_VIDEO = 14 # === Vision Inputs === # @@ -204,7 +203,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): min_pixels=min_pixels, max_pixels=max_pixels, size=size, - use_fast=kwargs.get("use_fast")), + # TODO: mindone not support fast processor yet. + use_fast=kwargs.pop("use_fast", False)), **kwargs, ) @@ -339,6 +339,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): max_image_size, _ = self._get_vision_info( image_width=9999999, image_height=9999999, + num_frames=1, image_processor=None, ) return max_image_size @@ -352,10 +353,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): image_processor=None, ) - def _get_max_video_frames(self, max_tokens: int) -> int: + def _get_max_video_frames(self, + max_tokens: int, + start_num_frames: int = 1) -> int: target_width, target_height = self.get_image_size_with_most_features() - num_frames = 0 + num_frames = start_num_frames while True: next_num_frames = num_frames + 1 @@ -377,6 +380,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): self, seq_len: int, mm_counts: Mapping[str, int], + max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO ) -> int: max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) @@ -385,7 +389,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_frames_per_video) return max(max_frames_per_video, 1) @@ -428,7 +432,8 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): min_pixels=min_pixels, max_pixels=max_pixels, size=size, - use_fast=kwargs.get("use_fast")), + # TODO: mindone not support fast processor yet. 
+ use_fast=kwargs.pop("use_fast", False)), **kwargs, ) @@ -531,18 +536,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] def _get_data_parser(self) -> MultiModalDataParser: return Qwen2VLMultiModalDataParser() - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - ) -> BatchFeature: - return self.info.ctx.call_hf_processor( - self.info.get_hf_processor(**mm_kwargs), - dict(text=prompt, **mm_data), - self.info._get_image_processor_kwargs(**mm_kwargs), - ) - def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -563,7 +556,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, ms.Tensor) num_tokens = int(grid_thw.prod()) // merge_length @@ -608,7 +602,8 @@ class _Qwen2VLMultiModalProcessor(Qwen2VLMultiModalProcessor): merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, ms.Tensor) num_tokens = int(grid_thw.prod()) // merge_length @@ -1587,10 +1582,8 @@ class Qwen2_5_VLForConditionalGeneration(NativeModel, SupportsMultiModal): def compute_logits( self, hidden_states: ms.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[ms.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights( diff --git a/vllm_mindspore/model_executor/models/qwen3_vl.py b/vllm_mindspore/model_executor/models/qwen3_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..467b95d1ebb144704a9b3d8bdff81c351e3b7507 --- /dev/null +++ b/vllm_mindspore/model_executor/models/qwen3_vl.py @@ -0,0 +1,1651 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
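The `get_replacement_qwen2vl` helpers patched earlier derive the placeholder count from `grid_thw` and the spatial merge size via `grid_thw.prod() // merge_size**2`. A tiny worked example of that arithmetic (grid values are illustrative only):

```python
def num_placeholder_tokens(grid_thw: tuple[int, int, int], merge_size: int) -> int:
    # grid_thw is (temporal, height, width) in vision patches; the patch merger
    # collapses merge_size x merge_size spatial patches into one LLM token.
    t, h, w = grid_thw
    return (t * h * w) // (merge_size ** 2)

# e.g. a single image resized to a 36 x 36 patch grid with merge_size == 2:
print(num_placeholder_tokens((1, 36, 36), merge_size=2))  # 1296 patches -> 324 tokens
```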
+"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" + +import math +from collections.abc import Callable, Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Optional + +import mindspore as ms +import mindspore.mint.nn.functional as F +import numpy as np +from mindspore import Parameter, Tensor, mint, mutable, nn +from mindspore.common import dtype as mstype +from mindspore.ops.operations.nn_ops import FlashAttentionScore +from transformers.feature_extraction_utils import BatchFeature +from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( + smart_resize as image_smart_resize) +from transformers.models.qwen3_vl import (Qwen3VLProcessor, + Qwen3VLVideoProcessor) +from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig +from transformers.models.qwen3_vl.video_processing_qwen3_vl import ( + smart_resize as video_smart_resize) +from transformers.video_utils import VideoMetadata +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from vllm_mindspore.model_executor.layers.activation import ( + _ACTIVATION_REGISTRY) +from vllm_mindspore.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm_mindspore.model_executor.layers.logits_processor import ( + LogitsProcessor) +from vllm_mindspore.model_executor.layers.rotary_embedding import ( + _apply_rotary_emb) +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead) +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + default_weight_loader) +from vllm_mindspore.model_executor.models.attention_mask import ( + MultiModalLowerTriangularMask) +from vllm_mindspore.model_executor.models.interfaces import ( + MultiModalEmbeddings, SupportsMultiModal) +from vllm_mindspore.model_executor.models.model_base import (AttentionWrapper, + NativeModel) +from vllm_mindspore.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionAttention, Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs, Qwen2VLProcessingInfo) +from vllm_mindspore.model_executor.models.qwen3 import (Qwen3ForCausalLM, + Qwen3Model) +from vllm_mindspore.model_executor.models.utils import ( + WeightsMapper, _merge_multimodal_embeddings, maybe_prefix) +from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE + +try: + from ms_custom_ops import apply_rotary_pos_emb_atb + is_custom_rope_available = True +except ImportError: + is_custom_rope_available = False + +logger = init_logger(__name__) + +# Official recommended max pixels is 24576 * 32 * 32 +_MAX_FRAMES_PER_VIDEO = 24576 + + +class Qwen3_VisionAttention(Qwen2_5_VisionAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + 
self.flash_attention_score = FlashAttentionScore( + head_num=self.num_attention_heads_per_partition, + scale_value=1 / math.sqrt(self.hidden_size_per_attention_head), + input_layout="TH") + self.apply_rope = self._custom_ops_rope if is_custom_rope_available \ + else self._native_rope + + def _native_rope(self, q, k, cos, sin, batch_valid_length): + seq_length = q.shape[0] + q = q.reshape(seq_length, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + k = k.reshape(seq_length, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q = _apply_rotary_emb(q, cos, sin, True) + k = _apply_rotary_emb(k, cos, sin, True) + q = q.reshape( + seq_length, self.num_attention_heads_per_partition * + self.hidden_size_per_attention_head) + k = k.reshape( + seq_length, self.num_attention_heads_per_partition * + self.hidden_size_per_attention_head) + return q, k + + def _custom_ops_rope(self, q, k, cos, sin, batch_valid_length): + cos = mint.cat((cos, cos), dim=-1) + sin = mint.cat((sin, sin), dim=-1) + q, k = apply_rotary_pos_emb_atb(q, k, cos, sin, batch_valid_length, 2, + 0) + return q, k + + def construct(self, x: Tensor, batch_valid_length: Tensor, + position_embeddings: tuple[ms.Tensor, ms.Tensor], + q_seq_lens: Tensor) -> Tensor: + qkv, _ = self.qkv(x) + q, k, v = mint.split( + qkv, (self.num_attention_heads_per_partition * self.head_dim, + self.num_attention_heads_per_partition * self.head_dim, + self.num_attention_heads_per_partition * self.head_dim), -1) + cos, sin = position_embeddings + origin_dtype = q.dtype + + q, k = self.apply_rope(q, k, cos, sin, batch_valid_length) + + # q/k reshape to TH + q = q.astype(origin_dtype) + k = k.astype(origin_dtype) + + _, _, _, context_layer = self.flash_attention_score( + q, + k, + v, + None, + None, + None, + None, + None, + batch_valid_length, + q_seq_lens, + ) + output, _ = self.proj(context_layer) + return output + + +class Qwen3_VisionPatchEmbed(nn.Cell): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + self.dtype = get_current_vllm_config().model_config.dtype + + # Use Dense layer instead of Conv3d for MindSpore compatibility + self.proj = ms.nn.Dense(temporal_patch_size * patch_size * patch_size * + in_channels, + hidden_size, + has_bias=True, + dtype=self.dtype) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + x = self.proj(x) + return x + + +class Qwen3_VisionMLP(nn.Cell): + + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu, + quant_config=None, + prefix: str = "", + ): + super().__init__() + self.linear_fc1 = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc1") + self.linear_fc2 = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc2") + self.act_fn = act_fn + + def construct(self, x: ms.Tensor): + mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return mlp_output + + +class Qwen3_VisionBlock(nn.Cell): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu, + norm_layer: Callable[[int], nn.Cell] | 
None = None, + quant_config=None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(mint.nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen3_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.mlp = Qwen3_VisionMLP( + dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def construct( + self, + x: ms.Tensor, + batch_valid_length: ms.Tensor, + position_embeddings: ms.Tensor, + q_seq_lens: ms.Tensor, + ) -> ms.Tensor: + x = x + self.attn(self.norm1(x), batch_valid_length, + position_embeddings, q_seq_lens) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen3_VisionPatchMerger(nn.Cell): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Callable[[int], nn.Cell] | None = None, + spatial_merge_size: int = 2, + use_postshuffle_norm: bool = False, + quant_config=None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + + self.use_postshuffle_norm = use_postshuffle_norm + if self.use_postshuffle_norm: + context_dim = self.hidden_size + + if norm_layer is None: + norm_layer = partial(mint.nn.LayerNorm, eps=1e-6) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = norm_layer(context_dim) + self.linear_fc1 = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc1", + ) + self.act_fn = nn.GELU() + self.linear_fc2 = RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc2", + ) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + if self.use_postshuffle_norm: + x = self.norm(x.view(-1, self.hidden_size)) + else: + x = self.norm(x).view(-1, self.hidden_size) + + x_parallel, _ = self.linear_fc1(x) + x_parallel = self.act_fn(x_parallel) + out, _ = self.linear_fc2(x_parallel) + return out + + +class Qwen3_VisionTransformer(nn.Cell): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config=None, + prefix: str = "", + ) -> None: + super().__init__() + self.vision_config = vision_config + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.num_position_embeddings = vision_config.num_position_embeddings + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.num_grid_per_side = int(self.num_position_embeddings**0.5) + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + self.pos_embed = mint.nn.Embedding(self.num_grid_per_side**2, + self.hidden_size, + dtype=self.dtype) + + norm_layer = partial(mint.nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.CellList([ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + 
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) for layer_idx in range(vision_config.depth) + ]) + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + if self.deepstack_visual_indexes is not None: + self.deepstack_merger_list = nn.CellList([ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", + ) for layer_idx in range(len(self.deepstack_visual_indexes)) + ]) + + @property + def dtype(self) -> ms.dtype: + return self.patch_embed.proj.weight.dtype + + def construct( + self, + x: ms.Tensor, + batch_valid_length: ms.Tensor, + q_seq_lens: ms.Tensor, + rotary_pos_emb: ms.Tensor, + pos_embeds: ms.Tensor, + ) -> ms.Tensor: + hidden_states = x.astype(self.dtype) + hidden_states = self.patch_embed(hidden_states) + + hidden_states = hidden_states + pos_embeds + seq_len, _ = x.shape + rotary_pos_emb = rotary_pos_emb.astype(hidden_states.dtype) + emb = rotary_pos_emb + position_embeddings = (mint.cos(emb), mint.sin(emb)) + + hidden_states_list = [] + deepstack_visual_indexes = self.deepstack_visual_indexes + + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, batch_valid_length, + position_embeddings, q_seq_lens) + if (deepstack_visual_indexes is not None + and layer_num in deepstack_visual_indexes): + hidden_states_list.append(hidden_states) + + hidden_states = self.merger(hidden_states) + + # processing deepstack + if deepstack_visual_indexes is not None: + processed_hidden_states_list = [hidden_states] + for idx, x in enumerate(hidden_states_list): + x = self.deepstack_merger_list[idx](x) + processed_hidden_states_list.append(x) + # we cat the original visual features and deepstack features + # along the feature dim + hidden_states = mint.cat( + processed_hidden_states_list, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, Tensor]], + params_dict: dict[str, Parameter]) -> set[str]: + stacked_params_mapping = [ + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ] + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + if "patch_embed.proj.weight" in name: + loaded_weight = loaded_weight[:] + loaded_weight = loaded_weight.reshape( + loaded_weight.shape[0], -1) + param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def set_model_inputs(self): + x_dtype = get_current_vllm_config().model_config.dtype + dyn_x = ms.Tensor(shape=[None, None], dtype=x_dtype) + dyn_batch_valid_length = 
ms.Tensor(shape=[None], dtype=ms.int32) + dyn_q_seq_lens = ms.Tensor(shape=[None], dtype=ms.int32) + dyn_rotary_pos_emb = ms.Tensor(shape=[None, None], dtype=ms.float32) + dyn_pos_emb = ms.Tensor(shape=[None, None], dtype=x_dtype) + + self.set_inputs(dyn_x, dyn_batch_valid_length, dyn_q_seq_lens, + dyn_rotary_pos_emb, dyn_pos_emb) + + +class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3VLConfig) + + def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor: + return self.ctx.get_hf_processor( + Qwen3VLProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + + def get_tokenizer(self): + return self.ctx.tokenizer + + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor + + def get_video_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).video_processor + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 2, + do_resize: bool = True, + image_processor=None, + ) -> tuple[ImageSize, int]: + if image_processor is None and num_frames > 1: + image_processor = self.get_video_processor() + elif image_processor is None: + image_processor = self.get_image_processor() + + is_video = isinstance(image_processor, Qwen3VLVideoProcessor) + + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + temporal_patch_size = vision_config.temporal_patch_size + + if do_resize: + if is_video: + smart_resize = video_smart_resize + extra_kwargs = { + "num_frames": num_frames, + "temporal_factor": temporal_patch_size, + } + else: + smart_resize = image_smart_resize + extra_kwargs = {} + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.size["shortest_edge"], + max_pixels=image_processor.size["longest_edge"], + **extra_kwargs, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + padded_num_frames = num_frames + num_frames % temporal_patch_size + + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (merge_size**2) + + return preprocessed_size, num_vision_tokens + + def _get_max_video_frames(self, + max_tokens: int, + start_num_frames: int = 2) -> int: + return super()._get_max_video_frames(max_tokens, + start_num_frames=start_num_frames) + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + return super().get_num_frames_with_most_features( + seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + video_soft_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + # NOTE: By default in Qwen3-VL, one video token is converted to + # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # 
noqa: E501 + formatted_video_soft_tokens = video_soft_tokens * 12.5 + return int(formatted_video_soft_tokens) + + def _calculate_timestamps(self, indices: list[int] | Tensor, + video_fps: float, merge_size: int): + if not isinstance(indices, list): + indices = indices.tolist() + if len(indices) % merge_size != 0: + # don't update metadata's frames_indices directly + indices = indices + [indices[-1] + ] * (merge_size - len(indices) % merge_size) + timestamps = [idx / video_fps for idx in indices] + timestamps = [(timestamps[i] + timestamps[i + merge_size - 1]) / 2 + for i in range(0, len(timestamps), merge_size)] + return timestamps + + def _get_video_second_idx( + self, + metadata: dict[str, Any], + out_item: MultiModalKwargs, + do_sample_frames: bool | None = None, + sampled_fps: float | None = None, + ) -> list[int]: + video_processor = self.get_video_processor() + merge_size = video_processor.merge_size + indices = metadata["frames_indices"] + + # metadata["fps"] refers to the true fps of the input video. + video_fps = metadata["fps"] + if do_sample_frames is None: + do_sample_frames = metadata.get("do_sample_frames", False) + + # If video frames are sampled in HF processor (instead of vLLM + # video loader), we need to re-calculate the indices from original + # metadata. + if do_sample_frames: + # here video_fps is the fps of the sampled video, and + # metadata["fps"] refers to the fps of the original video. + sampled_fps = sampled_fps if sampled_fps else video_processor.fps + total_num_frames = metadata["total_num_frames"] + num_frames = int(total_num_frames / metadata["fps"] * sampled_fps) + num_frames = min( + min( + max(num_frames, video_processor.min_frames), + video_processor.max_frames, + ), + total_num_frames, + ) + indices = (np.linspace(0, total_num_frames - 1, + num_frames).round().astype(int).tolist()) + timestamps = self._calculate_timestamps(indices, video_fps, merge_size) + return timestamps + + +class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + image_token = "<|vision_start|><|image_pad|><|vision_end|>" + video_token = "<|vision_start|><|video_pad|><|vision_end|>" + + return image_token * num_images + video_token * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = ( + self.info.get_image_size_with_most_features()) + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts) + + target_video_size, _ = self.info._get_vision_info( + image_width=target_width, + image_height=target_height, + num_frames=target_num_frames, + image_processor=self.info.get_video_processor(), + ) + # NOTE: we need to do this check here since Qwen3-VL resizes video + # frames depending on how many frames there are. 
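`_calculate_timestamps` above pads the sampled frame indices to a multiple of `merge_size`, converts them to seconds, and keeps one timestamp per merged temporal group (the mid-point of the group's first and last frame). A standalone sketch of the same arithmetic with toy values:

```python
def calculate_timestamps(indices: list[int], video_fps: float, merge_size: int) -> list[float]:
    # Pad by repeating the last index so the length is a multiple of merge_size.
    if len(indices) % merge_size != 0:
        indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size)
    seconds = [idx / video_fps for idx in indices]
    # One timestamp per merged temporal group: mid-point of first and last frame.
    return [(seconds[i] + seconds[i + merge_size - 1]) / 2
            for i in range(0, len(seconds), merge_size)]

# 5 sampled frames at 2 fps, merged in pairs -> 3 timestamps
print(calculate_timestamps([0, 2, 4, 6, 8], video_fps=2.0, merge_size=2))
# [0.5, 2.5, 4.0]
```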
+        width, height = target_video_size.width, target_video_size.height
+        return {
+            "image":
+            self._get_dummy_images(
+                width=target_width,
+                height=target_height,
+                num_images=num_images,
+            ),
+            "video":
+            self._get_dummy_videos(
+                width=target_video_size.width,
+                height=target_video_size.height,
+                num_frames=target_num_frames,
+                num_videos=num_videos,
+            ),
+        }
+
+    def _get_dummy_videos(
+        self,
+        *,
+        width: int,
+        height: int,
+        num_frames: int,
+        num_videos: int,
+    ):
+        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
+        video_items = []
+        for i in range(num_videos):
+            video_metadata = {
+                "fps": 2.0,
+                "duration": num_frames / 2.0,
+                "total_num_frames": num_frames,
+                "frames_indices": [i for i in range(num_frames)],
+                "video_backend": "opencv",
+                "do_sample_frames": False,
+            }
+            video_item = (video.copy(), video_metadata)
+            video_items.append(video_item)
+        return video_items
+
+
+class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]
+                                 ):
+
+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(video_needs_metadata=True)
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+        tok_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        mm_data = dict(mm_data)
+        processor = self.info.get_hf_processor(**mm_kwargs)
+
+        # Separate video processing from image processing, because the
+        # videos are processed into several image patches.
+        if ("videos" in mm_data and isinstance(mm_data["videos"], list)
+                and len(mm_data["videos"]) > 0):
+            video_grid_thw_lst = []
+            pixel_values_videos_lst = []
+
+            for item_idx, item in enumerate(mm_data.pop("videos", [])):
+                video_array, metadata = item
+
+                # NOTE: @JJJYmmm new attr metadata.frames_indices indicates
+                # the sampled frames indices of pre-sampled videos, which is
+                # used to calculate the timestamps. Make sure that
+                # do_sample_frames in mm_kwargs is false for presampled
+                # videos.
+
+                # NOTE: a copy of mm_kwargs is created to update
+                # do_sample_frames, otherwise mm_hash for the object will be
+                # incorrect.
+                video_mm_kwargs = dict(**mm_kwargs)
+                if "do_sample_frames" not in video_mm_kwargs:
+                    # qwen_vl_utils already has "do_sample_frames" in
+                    # mm_kwargs, don't overwrite it.
+ video_mm_kwargs["do_sample_frames"] = metadata.get( + "do_sample_frames", False) + + metadata = VideoMetadata(**{ + k: metadata[k] + for k in metadata if k != "do_sample_frames" + }) + + video_mm_data = dict() + video_mm_data["videos"] = [[video_array]] + video_mm_data["video_metadata"] = [[metadata]] + + video_outputs = super()._call_hf_processor( + prompt="<|vision_start|><|video_pad|><|vision_end|>", + mm_data=video_mm_data, + mm_kwargs=video_mm_kwargs, + ) + input_ids = video_outputs.pop("input_ids") + video_placeholder = processor.tokenizer.batch_decode( + input_ids)[0] + prompt = prompt.replace( + "<|vision_start|><|video_pad|><|vision_end|>", + video_placeholder, + 1, + ) + + video_grid_thw_lst.append(video_outputs["video_grid_thw"]) + pixel_values_videos_lst.append( + video_outputs["pixel_values_videos"]) + video_outputs = dict( + pixel_values_videos=mint.cat(pixel_values_videos_lst), + video_grid_thw=mint.cat(video_grid_thw_lst), + ) + else: + video_outputs = dict() + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + combined_outputs = dict( + processed_outputs, + **video_outputs, + ) + return BatchFeature(combined_outputs) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_thw = hf_inputs.get("image_grid_thw", mint.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", mint.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + hf_config = self.info.get_hf_config() + + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + vision_end_token_id = hf_config.vision_end_token_id + + merge_length = image_processor.merge_size**2 + + def get_image_replacement_qwen3vl(item_idx: int): + out_item = out_mm_kwargs["image"][item_idx] + grid_thw = out_item["image_grid_thw"].data + + assert isinstance(grid_thw, Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + return [hf_processor.image_token_id] * num_tokens + + def get_video_replacement_qwen3vl(item_idx: int): + out_item = out_mm_kwargs["video"][item_idx] + grid_thw = out_item["video_grid_thw"].data + + assert isinstance(grid_thw, Tensor) + + video, metadata = mm_items["video"][item_idx] + do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames") + sampled_fps = hf_processor_mm_kwargs.get("fps") + if is_list_of(sampled_fps, float): + sampled_fps = sampled_fps[item_idx] + out_item = None + timestamps = self.info._get_video_second_idx( + metadata, out_item, do_sample_frames, 
sampled_fps) + + assert len(timestamps) == grid_thw[0], ( + f"The timestamps length({len(timestamps)}) should be equal " + f"video length ({grid_thw[0]}).") + + frames_idx_token = [ + tokenizer.encode(f"<{curr_time:.1f} seconds>", + add_special_tokens=False) + for curr_time in timestamps + ] + num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + placeholder = [] + for frame_idx in frames_idx_token: + placeholder.extend(frame_idx) + placeholder.extend([vision_start_token_id] + + [video_token_id] * num_tokens_per_frame + + [vision_end_token_id]) + return PromptUpdateDetails.select_token_id(placeholder, + video_token_id) + + return [ + PromptReplacement( + modality="image", + target=hf_processor.image_token, + replacement=get_image_replacement_qwen3vl, + ), + # NOTE: We match string on purpose since searching sequence of + # token ids takes more time. + PromptReplacement( + modality="video", + target="<|vision_start|><|video_pad|><|vision_end|>", + replacement=get_video_replacement_qwen3vl, + ), + ] + + +class Qwen3LLMModel(Qwen3Model): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + deepstack_layers: int = 0): + super().__init__(vllm_config=vllm_config, prefix=prefix) + if not get_pp_group().is_first_rank: + assert self.start_layer >= len( + vllm_config.model_config.hf_config.vision_config. + deepstack_visual_indexes), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)") + self.deepstack_layers = deepstack_layers + + def construct( + self, + input_ids: Tensor, + positions: Tensor, + key_caches: list[Tensor], + value_caches: list[Tensor], + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + deepstack_input_embeds: Optional[Mapping[str, Tensor]] = None, + ) -> ms.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + key_caches[i - self.start_layer], + value_caches[i - self.start_layer], + slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, + block_tables, residual) + + if deepstack_input_embeds is not None and i in range( + self.deepstack_layers): + hidden_states = mint.add(hidden_states, + deepstack_input_embeds[i]) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3LLMForCausalLM(Qwen3ForCausalLM): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + deepstack_layers: int = 0): + super(Qwen3ForCausalLM, self).__init__(vllm_config=vllm_config) + config = vllm_config.model_config.hf_config.text_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.model = Qwen3LLMModel(vllm_config=vllm_config, + prefix=prefix, + deepstack_layers=deepstack_layers) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + 
quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=Qwen3VLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) +class Qwen3VLForConditionalGeneration( + NativeModel, + SupportsMultiModal, +): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + #TODO: To support 'mm_encoder_tp_mode == "data"', + # Linear in layers should be refactored first. + supports_encoder_tp_data = False + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|video_pad|><|vision_end|>" + + raise ValueError("Only image or video modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.vision_config = config.vision_config + self.text_config = config.text_config + + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + if not multimodal_config.get_limit_per_prompt("image") and \ + not multimodal_config.get_limit_per_prompt("video"): + self.visual = None + else: + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + self.visual.set_model_inputs() + self.visual.construct = ms.jit(function=self.visual, + jit_level='O0') + + self.use_deepstack = hasattr(config.vision_config, + 'deepstack_visual_indexes') + self.deepstack_num_level = len( + config.vision_config.deepstack_visual_indexes + ) if self.use_deepstack else 0 + + self.language_model = Qwen3LLMForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + deepstack_layers=self.deepstack_num_level) + + self.model = self.language_model.model + self.lm_head = self.language_model.lm_head + self.common_preprocess(vllm_config, prefix) + + self.model.embed_tokens._set_jit_graph_name("prefill") + self.model.embed_tokens.phase = "prefill" + dyn_input_ids = ms.Tensor(shape=[None], dtype=ms.int32) + self.model.embed_tokens.set_inputs(dyn_input_ids) + self.model.embed_tokens.construct = ms.jit( + function=self.model.embed_tokens, jit_level='O0') + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + # register buffer for deepstack + if self.use_deepstack and self.visual is not None: + self.deepstack_input_embeds = [ + mint.zeros( + (vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.hidden_size), + dtype=self.model_config.dtype) + for _ in range(self.deepstack_num_level) + ] + else: + 
self.deepstack_input_embeds = None + self.visual_dim = config.vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level + + head_dim = (self.vision_config.hidden_size // + self.vision_config.num_heads) + self.rotary_pos_emb_full = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + def common_preprocess(self, vllm_config, prefix=""): + self.set_modules({ + "model.visual": self.visual, + "model.language_model": self.language_model.model, + "lm_head": self.language_model.lm_head + }) + self.casual_mask = MultiModalLowerTriangularMask( + dtype=self.model_config.dtype, + max_model_len=self.model_config.max_model_len) + self.kv_caches = [ + AttentionWrapper() + for i in range(self.text_config.num_hidden_layers) + ] + + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + for i in range(self.text_config.num_hidden_layers): + compilation_config.static_forward_context[str( + i)] = self.kv_caches[i] + + def _get_deepstack_input_embeds(self, + num_tokens: int) -> IntermediateTensors: + # get deepstack_input_embeds from buffer, and clear the buffer + deepstack_input_embeds = \ + [self.deepstack_input_embeds[idx][:num_tokens] + for idx in range(self.deepstack_num_level)] + deepstack_input_embeds = mint.stack(deepstack_input_embeds, dim=0) + return deepstack_input_embeds + + def _set_deepstack_input_embeds(self, + deepstack_input_embeds: ms.Tensor) -> None: + # set deepstack_input_embeds to buffer + num_tokens = deepstack_input_embeds.shape[1] + if num_tokens > self.deepstack_input_embeds[0].shape[0]: + self.deepstack_input_embeds = [ + mint.zeros( + (num_tokens, self.config.text_config.hidden_size), + dtype=self.deepstack_input_embeds[0].dtype, + ) for _ in range(self.deepstack_num_level) + ] + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[ + idx][:num_tokens] = deepstack_input_embeds[idx] + + def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: + # clear deepstack_input_embeds in buffer + if num_tokens > 0: + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].zero_() + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> Tensor: + if not isinstance(mm_input, (Tensor, list)): + raise ValueError( + f"Incorrect type of {name}. Got type: {type(mm_input)}") + if isinstance(mm_input, Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return mm_input.reshape(-1, mm_input.shape[-1]) + else: + return mint.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Qwen2_5_VLImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (Tensor, list)): + raise ValueError("Incorrect type of image pixel values. 
" + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Qwen2_5_VLVideoInputs | None: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, ms.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input(self, image_input) -> tuple[Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"].type(self.visual.dtype) + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + pos_emb = self.fast_pos_embed_interpolate(grid_thw.tolist()) + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + grid_thw_1 = grid_thw.index_select(1, ms.Tensor([1])).reshape(-1) + grid_thw_2 = grid_thw.index_select(1, ms.Tensor([2])).reshape(-1) + grid_thw_0 = grid_thw.index_select(1, ms.Tensor([0])).reshape(-1) + batch_valid_length = mint.repeat_interleave( + grid_thw_1 * grid_thw_2, grid_thw_0).astype(ms.int32) + image_embeds = self.visual(pixel_values, batch_valid_length, + batch_valid_length, rotary_pos_emb, pos_emb) + # Split concatenated embeddings for each image item. 
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(
+            self, video_input: Qwen2_5_VLVideoInputs) -> tuple[Tensor, ...]:
+        if video_input["type"] == "video_embeds":
+            return video_input["video_embeds"].type(self.visual.dtype)
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        pos_emb = self.fast_pos_embed_interpolate(grid_thw.tolist())
+        pixel_values = video_input["pixel_values_videos"].type(
+            self.visual.dtype)
+        grid_thw_1 = grid_thw.index_select(1, ms.Tensor([1])).reshape(-1)
+        grid_thw_2 = grid_thw.index_select(1, ms.Tensor([2])).reshape(-1)
+        grid_thw_0 = grid_thw.index_select(1, ms.Tensor([0])).reshape(-1)
+        batch_valid_length = mint.repeat_interleave(
+            grid_thw_1 * grid_thw_2, grid_thw_0).astype(ms.int32)
+        video_embeds = self.visual(pixel_values, batch_valid_length,
+                                   batch_valid_length, rotary_pos_emb,
+                                   pos_emb)
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+        return video_embeds.split(sizes.tolist())
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        mm_input_by_modality = {}
+        for input_key in kwargs:
+            if (input_key in ("pixel_values", "image_embeds")
+                    and "image" not in mm_input_by_modality):
+                mm_input_by_modality[
+                    "image"] = self._parse_and_validate_image_input(**kwargs)
+            if (input_key in ("pixel_values_videos", "video_embeds")
+                    and "video" not in mm_input_by_modality):
+                mm_input_by_modality[
+                    "video"] = self._parse_and_validate_video_input(**kwargs)
+        return mm_input_by_modality
+
+    def get_language_model(self):
+        return self.language_model
+
+    def get_multimodal_embeddings(
+            self, **kwargs: object) -> MultiModalEmbeddings | None:
+        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
+            **kwargs)
+        if not mm_input_by_modality:
+            return None
+
+        # The result multimodal_embeddings is a tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+        multimodal_embeddings: tuple[Tensor, ...] = ()
+
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
+ for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += tuple(video_embeddings) + return multimodal_embeddings + + def _compute_deepstack_embeds( + self, + inputs_embeds: ms.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + is_multimodal: ms.Tensor, + ) -> tuple[ms.Tensor, MultiModalEmbeddings]: + visual_lens = [len(x) for x in multimodal_embeddings] + multimodal_embeddings_cat = mint.cat(multimodal_embeddings, dim=0) + + ( + multimodal_embeddings_main, + multimodal_embeddings_multiscale, + ) = mint.split( + multimodal_embeddings_cat, + [self.visual_dim, self.multiscale_dim], + dim=-1, + ) + + multimodal_embeddings = mint.split(multimodal_embeddings_main, + visual_lens, + dim=0) + multimodal_embeddings_multiscale = mint.split( + multimodal_embeddings_multiscale, visual_lens, dim=0) + + deepstack_input_embeds = inputs_embeds.new_zeros( + inputs_embeds.shape[0], + self.deepstack_num_level * inputs_embeds.shape[1]) + + deepstack_input_embeds = _merge_multimodal_embeddings( + inputs_embeds=deepstack_input_embeds, + multimodal_embeddings=multimodal_embeddings_multiscale, + is_multimodal=is_multimodal, + ) + deepstack_input_embeds = deepstack_input_embeds.view( + inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim) + deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2) + + return deepstack_input_embeds, multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: ms.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + ) -> ms.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + placeholder_token_id = [ + self.config.image_token_id, self.config.video_token_id + ] + is_multimodal = ms.numpy.isin(input_ids, placeholder_token_id) + + if self.use_deepstack: + ( + deepstack_input_embeds, + multimodal_embeddings, + ) = self._compute_deepstack_embeds( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + else: + deepstack_input_embeds = None + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + if deepstack_input_embeds is not None: + self._set_deepstack_input_embeds(deepstack_input_embeds) + + return inputs_embeds + + def forward( + self, + input_ids: ms.Tensor, + positions: ms.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: ms.Tensor | None = None, + **kwargs: object, + ) -> ms.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + if (self.use_deepstack and inputs_embeds is not None + and get_pp_group().is_first_rank): + deepstack_input_embeds = self._get_deepstack_input_embeds( + inputs_embeds.shape[0]) + else: + deepstack_input_embeds = None + + hidden_states = self.exec_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + # args for deepstack + deepstack_input_embeds=deepstack_input_embeds, + ) + + if inputs_embeds is not None and get_pp_group().is_first_rank: + 
self._clear_deepstack_input_embeds(inputs_embeds.shape[0]) + + return hidden_states + + def compute_logits( + self, + hidden_states: ms.Tensor, + ) -> ms.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + ms.Tensor]]) -> set[str]: + params_dict = self.get_params_dict() + loaded_param = set() + visual_load = set() + text_load = set() + for name, weight in weights: + if "model.visual." in name: + visual_load.update( + self.visual.load_weights([(name, weight)], params_dict)) + elif "model.language_model." in name: + text_load.update( + self.model.load_weights([(name, weight)], params_dict)) + else: + # Handle other weights + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_param.add(name) + loaded_param.update(visual_load) + loaded_param.update(text_load) + return loaded_param + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.merger", + tower_model="visual.", + ) + + def rot_pos_emb(self, grid_thw: ms.Tensor) -> ms.Tensor: + spatial_merge_size = self.vision_config.spatial_merge_size + pos_ids = [] + for t, h, w in grid_thw: + t, h, w = int(t.item()), int(h.item()), int(w.item()) + hpos_ids = mint.arange(h).unsqueeze(1).expand((-1, w)) + wpos_ids = mint.arange(w).unsqueeze(0).expand((h, -1)) + + hpos_ids = hpos_ids.reshape( + h // spatial_merge_size, + spatial_merge_size, + w // spatial_merge_size, + spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // spatial_merge_size, + spatial_merge_size, + w // spatial_merge_size, + spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + mint.tile(mint.stack([hpos_ids, wpos_ids], dim=-1), (t, 1))) + pos_ids = mint.cat(pos_ids, dim=0) + max_grid_size = int(grid_thw[:, 1:].max().item()) + rotary_pos_emb_full = self.rotary_pos_emb_full(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, + grid_thw: list[list[int]]) -> ms.Tensor: + num_grid_per_side = self.visual.num_grid_per_side + m_size = self.visual.spatial_merge_size + hidden_dim = self.visual.pos_embed.embedding_dim + + outputs = [] + for t, h, w in grid_thw: + h_idxs = mint.linspace(0, + num_grid_per_side - 1, + h, + dtype=ms.float32) + w_idxs = mint.linspace(0, + num_grid_per_side - 1, + w, + dtype=ms.float32) + + h_floor = h_idxs.astype(ms.int64) + w_floor = w_idxs.astype(ms.int64) + h_ceil = mint.clamp(h_floor + 1, 0, num_grid_per_side - 1) + w_ceil = mint.clamp(w_floor + 1, 0, num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = mint.meshgrid(dh, dw, indexing="ij") + h_floor_grid, w_floor_grid = mint.meshgrid(h_floor, + w_floor, + indexing="ij") + h_ceil_grid, w_ceil_grid = mint.meshgrid(h_ceil, + w_ceil, + indexing="ij") + h_floor_grid_idx = h_floor_grid * num_grid_per_side + h_ceil_grid_idx = h_ceil_grid * num_grid_per_side + + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - dw_grid + w11 + + idx00 = h_floor_grid_idx + w_floor_grid + idx01 = h_floor_grid_idx + w_ceil_grid + idx10 = h_ceil_grid_idx + w_floor_grid + idx11 = h_ceil_grid_idx + w_ceil_grid + + indices = mint.stack([idx00, 
idx01, idx10, idx11], + dim=0).reshape(4, -1) + weights = mint.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.astype(self.visual.dtype) + + embeds = self.visual.pos_embed(indices) + weighted_embeds = embeds * weights + p0, p1, p2, p3 = weighted_embeds.unbind(dim=0) + combined = p0 + p1 + p2 + p3 + + combined = combined.view(h * w, hidden_dim) + repeated = combined.unsqueeze(0).expand(t, -1, -1) + repeated = repeated.view(t, h // m_size, m_size, w // m_size, + m_size, hidden_dim) + repeated = repeated.permute(0, 1, 3, 2, 4, + 5).reshape(-1, hidden_dim) + outputs.append(repeated) + + return mint.cat(outputs, dim=0) + + def set_model_inputs(self, + input_ids=None, + position_ids=None, + intermediate_tensors=None, + inputs_embeds=None): + if input_ids is None: + dyn_input_ids = None + else: + dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim, + dtype=mstype.int32) + + if position_ids is None: + dyn_position_ids = None + else: + dyn_position_ids = ms.Tensor(shape=[None] * position_ids.ndim, + dtype=mstype.int32) + + if inputs_embeds is None: + dyn_inputs_embeds = None + else: + dyn_inputs_embeds = ms.Tensor(shape=[None] * inputs_embeds.ndim, + dtype=inputs_embeds.dtype) + + if intermediate_tensors is None: + dyn_intermediate_tensors = None + else: + dyn_intermediate_tensors = ms.Tensor( + shape=[None] * intermediate_tensors.ndim, + dtype=intermediate_tensors.dtype) + + block_size = self.cache_config.block_size + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + kv_cache_shape = (None, block_size, num_kv_heads, head_size) + + kv_cache_dtype = (self.model_config.dtype + if self.cache_config.cache_dtype == "auto" else + self.cache_config.cache_dtype) + if kv_cache_dtype in STR_DTYPE_TO_MS_DTYPE: + kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype] + + num_layers = self.model_config.get_num_layers(self.parallel_config) + + dyn_key_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) + dyn_value_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) + dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)]) + dyn_value_caches = mutable( + [dyn_value_cache for _ in range(num_layers)]) + + dyn_slot_mapping = Tensor(shape=[None], dtype=mstype.int32) + dynamic_attention_mask = Tensor(shape=[None, None], + dtype=self.model_config.dtype) + dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32) + dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) + dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) + dyn_deepstack_input_embeds = Tensor(shape=[None, None, None], + dtype=self.model_config.dtype) + + self.ready_model.set_inputs( + dyn_input_ids, dyn_position_ids, dyn_key_caches, dyn_value_caches, + dyn_slot_mapping, dynamic_attention_mask, dyn_batch_valid_length, + dyn_q_seq_lens, dyn_block_tables, dyn_intermediate_tensors, + dyn_inputs_embeds, dyn_deepstack_input_embeds) + + dynamic_hidden_states = Tensor(shape=[None, None], + dtype=self.model_config.dtype) + self.ready_lm_head.set_inputs(dynamic_hidden_states) diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py index 8f025436d3bda684f1530a8ffd0795a6f7f48f13..d48428a44d18d02ccfd4eb9561a28c37797e9dfa 100644 --- a/vllm_mindspore/model_executor/models/registry.py +++ b/vllm_mindspore/model_executor/models/registry.py @@ -68,6 +68,7 @@ _NATIVE_MODELS = { "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", 
"Qwen2_5_VLForConditionalGeneration"), "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), + "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), } _MINDFORMERS_MODELS = { diff --git a/vllm_mindspore/model_executor/models/vision.py b/vllm_mindspore/model_executor/models/vision.py new file mode 100644 index 0000000000000000000000000000000000000000..804e2310d7908da1e033e8512f01e4a1cee4684a --- /dev/null +++ b/vllm_mindspore/model_executor/models/vision.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/model_executor/models/vision.py +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mindspore import Tensor, mint + +def get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[int], + grid_hs: Tensor, + grid_ws: Tensor, +) -> Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx].item() // spatial_merge_size + llm_grid_w = grid_ws[vision_idx].item() // spatial_merge_size + h_index = ( + mint.arange(llm_grid_h) + .view(1, -1, 1) + .expand(len(t_index), -1, llm_grid_w) + .flatten() + ) + w_index = ( + mint.arange(llm_grid_w) + .view(1, 1, -1) + .expand(len(t_index), llm_grid_h, -1) + .flatten() + ) + t_index_tensor = ( + Tensor(t_index) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .long() + .flatten() + ) + _llm_pos_ids = mint.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = mint.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids diff --git a/vllm_mindspore/multimodal/inputs.py b/vllm_mindspore/multimodal/inputs.py index 64f604c803a653451c3144c23c7d31d8ea5e0553..d8add9dbfe32e07c4fa1d188e2ea0b971f025388 100644 --- a/vllm_mindspore/multimodal/inputs.py +++ b/vllm_mindspore/multimodal/inputs.py @@ -121,48 +121,6 @@ def flat_build_elems( field_factory = self._field_factory(modality=modality, key=key) return [field_factory(data[cast(slice, s)]) for s in self.slices] - -def batched_reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: - # NOTE: vLLM-MindSpore Plugin: - # Currently mindspore does not support operating tensors in a - # multi-threaded environment, so convert tensors to numpy. 
- if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): - if len(batch) == 1: - # An optimization when `batch` contains only one tensor: - # - produce exactly same result as `torch.stack(batch)` - # - will achieve zero-copy if the tensor is contiguous - return mindspore.from_numpy(np.expand_dims(batch[0].numpy(), 0)) - first_shape = batch[0].shape - if all(elem.shape == first_shape for elem in batch): - return mindspore.from_numpy(np.stack([b.numpy() for b in batch])) - - return batch - - -def flat_reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: - # NOTE: vLLM-MindSpore Plugin: - # Currently mindspore does not support operating tensors in a - # multi-threaded environment, so convert tensors to numpy. - if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): - if len(batch) == 1: - # An optimization when `batch` contains only one tensor: - # - produce exactly same result as `torch.concat(batch)` - # - will achieve zero-copy if the tensor is contiguous - return mindspore.from_numpy(batch[0].numpy()) - - def _expect_same_shape(tensor: torch.Tensor): - return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:] - - first_shape = _expect_same_shape(batch[0]) - - if all(_expect_same_shape(elem) == first_shape for elem in batch): - return mindspore.from_numpy( - np.concatenat([b.numpy() for b in batch], axis=self.dim)) - - assert self.dim == 0, "dim == 0 is required for nested list" - return [e for elem in batch for e in elem] - - @staticmethod def _try_stack(nested_tensors: NestedTensors, pin_memory: bool = False) -> NestedTensors: diff --git a/vllm_mindspore/inputs/registry.py b/vllm_mindspore/multimodal/processing.py similarity index 82% rename from vllm_mindspore/inputs/registry.py rename to vllm_mindspore/multimodal/processing.py index c6a0be6493b988bfee3623dfb87da903c2c944b3..3cb7175a5d249cf07f2a6e13b98a07dd04a19bf0 100644 --- a/vllm_mindspore/inputs/registry.py +++ b/vllm_mindspore/multimodal/processing.py @@ -15,8 +15,9 @@ # limitations under the License. """Adaption for input processor.""" -from vllm.inputs.registry import (BatchFeature, InputProcessingContext, - Mapping, ProcessorMixin) +from collections.abc import Mapping +from transformers import BatchFeature, ProcessorMixin +from vllm.multimodal.processing import InputProcessingContext origin_call_hf_processor = InputProcessingContext.call_hf_processor @@ -34,7 +35,8 @@ def call_hf_processor( def _wrapper(func): def _inner(*args, **kwargs): - kwargs["return_tensors"] = "np" + # origin return tensors is 'pt', to use mindone, should be 'ms' + kwargs["return_tensors"] = "ms" return func(*args, **kwargs) return _inner diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 5da6d0d55340b03044fd96f3bbce29319095c198..ffbcf3013de730ed836873d63c842884279615c2 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -108,16 +108,6 @@ class AscendPlatform(Platform): model_config = vllm_config.model_config model_config.disable_cascade_attn = True - # Cache between p0 and p1 effective only one-on-one situations. In data - # parallelelism, it is a one-to-many scenario, cache should be disabled. 
- if (model_config.multimodal_config is not None - and not model_config.disable_mm_preprocessor_cache - and parallel_config.data_parallel_size > 1): - model_config.multimodal_config.disable_mm_preprocessor_cache = True - logger.info( - "Disable mm preprocessor cache for data parallel size %d.", - parallel_config.data_parallel_size) - @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, diff --git a/vllm_mindspore/scripts.py b/vllm_mindspore/scripts.py index ed557bb56643024cf1cc6673e663472faf4f772a..1b11cfbb8dd5d185bb7f9ab3b26103ff02de5f78 100644 --- a/vllm_mindspore/scripts.py +++ b/vllm_mindspore/scripts.py @@ -18,7 +18,6 @@ import logging import os # It's before the vllm import, so vllm.logger cannot be used here. -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/vllm_mindspore/transformers_patch.py b/vllm_mindspore/transformers_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..781fafbbbddd5cd8ed951d7da87c3740b9dd4152 --- /dev/null +++ b/vllm_mindspore/transformers_patch.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import importlib +import sys +import types +import warnings + +os.environ["USE_TORCH"] = "FALSE" +os.environ["USE_TF"] = "FALSE" + +def _patch_processing_module(): + import mindone.transformers.models as mo_models + import transformers.models as tf_models + + # Gather all modules in mindone.transformers.models + for attr in dir(mo_models): + if attr.startswith("_"): + continue + try: + mo_submod = getattr(mo_models, attr) + tf_submod = getattr(tf_models, attr, None) + except Exception: + continue + if not isinstance(mo_submod, types.ModuleType): + continue + + # Get all "processing_"* or "*processing*" or "image_processing_"* in mindone + for sub_attr in dir(mo_submod): + if ("processing" in sub_attr): + try: + mo_proc_mod = getattr(mo_submod, sub_attr) + except Exception: + continue + + # Try to get matching transformers module + tf_modname = f"transformers.models.{attr}.{sub_attr}" + mo_modname = f"mindone.transformers.models.{attr}.{sub_attr}" + if tf_modname in sys.modules and mo_modname in sys.modules: + sys.modules[tf_modname] = sys.modules[mo_modname] + else: + # Try to import if not already loaded + try: + tf_mod = importlib.import_module(tf_modname) + mo_mod = importlib.import_module(mo_modname) + sys.modules[tf_modname] = sys.modules[mo_modname] + except Exception: + continue + +def patch_transformers(): + try: + import mindone + except ImportError: + warnings.warn("mindone.transformers not installed, " + "skip patching transformers.") + return + + import transformers + transformers.utils.is_accelerate_available = \ + lambda *args, **kwargs: False + from mindone.transformers import ProcessorMixin + transformers.ProcessorMixin = ProcessorMixin + transformers.processing_utils.ProcessorMixin = ProcessorMixin + + from mindone.transformers import AutoProcessor + transformers.AutoProcessor = AutoProcessor + transformers.models.auto.processing_auto.AutoProcessor = AutoProcessor + + from mindone.transformers import AutoImageProcessor + transformers.AutoImageProcessor = AutoImageProcessor + + _patch_processing_module() diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 7cba47a2021605c9cf70832a6b4b550226a89cef..59121b2b19e7c134a257a283190e1460d56f7a62 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -28,7 +28,8 @@ import tempfile import uuid from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Callable, Generator, Generic, List, Mapping, + Optional, Tuple, Union) import numpy as np import torch @@ -458,3 +459,27 @@ def ms_memory_profiling( result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +# Adapted from: https://stackoverflow.com/a/47212782/5082708 +class LazyDict(Mapping[str, T], Generic[T]): + + def __init__(self, factory: dict[str, Callable[[], T]]): + self._factory = factory + self._dict: dict[str, T] = {} + + def __getitem__(self, key: str) -> T: + if key not in self._dict: + if key not in self._factory: + raise KeyError(key) + self._dict[key] = self._factory[key]() + return self._dict[key] + + def __setitem__(self, key: str, value: Callable[[], T]): + self._factory[key] = value + + def __iter__(self): + return iter(self._factory) + + def __len__(self): + return len(self._factory) diff --git a/vllm_mindspore/v1/serial_utils.py 
b/vllm_mindspore/v1/serial_utils.py index 07b49df717435f19247509b91023f1d123d1071a..f67d32b955f90fe1ad681e981d7fc9985cad10b1 100644 --- a/vllm_mindspore/v1/serial_utils.py +++ b/vllm_mindspore/v1/serial_utils.py @@ -5,9 +5,10 @@ from typing import Any, Union import mindspore as ms import numpy as np -import torch from msgspec import msgpack +np_bfloat16 = "bfloat16" + mstype_str_to_np_type = { "Bool": np.bool_, "Int8": np.int8, @@ -20,10 +21,11 @@ mstype_str_to_np_type = { "Uint64": np.uint64, "Float16": np.float16, "Float32": np.float32, + "BFloat16": np_bfloat16, } -def _decode_tensor(self, arr: Any) -> torch.Tensor: +def _decode_tensor(self, arr: Any) -> ms.Tensor: dtype, shape, data = arr # Copy from inline representation, to decouple the memory storage # of the message from the original buffer. And also make Torch @@ -32,17 +34,17 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: else bytearray(data) if not buffer: # torch.frombuffer doesn't like empty buffers assert 0 in shape - return torch.empty(shape, dtype=dtype) + return ms.mint.empty(shape, dtype=dtype) # Create uint8 array - arr = torch.frombuffer(buffer, dtype=torch.uint8) + arr = np.frombuffer(buffer, dtype=np.uint8) # Convert back to proper shape & type - arr = arr.numpy().view(dtype=mstype_str_to_np_type[dtype]).reshape(shape) + arr = arr.view(dtype=mstype_str_to_np_type[dtype]).reshape(shape) tensor = ms.from_numpy(arr) return tensor def _encode_tensor( - self, obj: torch.Tensor + self, obj: ms.Tensor ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None # view the tensor as a contiguous 1D array of bytes diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index a162f852214fa7b0e2ac50d70f453864428a1bd2..5407a5d969401755e187f571145bd4a8b2a333f6 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -1038,7 +1038,7 @@ def _calc_mrope_positions(self, scheduler_output): src_start = int(num_computed_tokens) src_end = int(num_computed_tokens + prompt_part_len) - self.mrope_positions_cpu[:, dst_start:dst_end] = \ + self.mrope_positions.cpu[:, dst_start:dst_end] = \ req.mrope_positions[:,src_start:src_end] mrope_pos_ptr += prompt_part_len @@ -1048,15 +1048,13 @@ def _calc_mrope_positions(self, scheduler_output): dst_start = mrope_pos_ptr dst_end = mrope_pos_ptr + completion_part_len - self.mrope_positions_cpu[:, dst_start:dst_end] = \ - MRotaryEmbedding.get_next_input_positions_tensor( - req.mrope_position_delta, - context_len=num_computed_tokens + - prompt_part_len, - seq_len=num_computed_tokens + - prompt_part_len + - completion_part_len, - ) + MRotaryEmbedding.get_next_input_positions_tensor( + out=self.mrope_positions.cpu, + out_offset=dst_start, + mrope_position_delta=req.mrope_position_delta, + context_len=num_computed_tokens + prompt_part_len, + num_new_tokens=completion_part_len, + ) mrope_pos_ptr += completion_part_len
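
The sketches below illustrate, in standalone form, a few of the mechanisms this patch introduces; every function name, configuration value and token id that does not appear in the patch itself is illustrative only.

_get_vision_info reduces a resized frame grid to vision tokens by cutting the pixels into patches and then merging merge_size x merge_size patch groups (frames are first padded to a multiple of temporal_patch_size). A plain-Python version of that arithmetic, with assumed rather than checkpoint-derived sizes, looks like this:

    def num_vision_tokens(height, width, num_frames=2, patch_size=16,
                          merge_size=2, temporal_patch_size=2):
        # pad the frame count so it divides evenly into temporal patches
        padded_frames = num_frames + num_frames % temporal_patch_size
        grid_t = max(padded_frames // temporal_patch_size, 1)
        grid_h = height // patch_size
        grid_w = width // patch_size
        num_patches = grid_t * grid_h * grid_w
        # each merge_size x merge_size patch group becomes one vision token
        return num_patches // (merge_size ** 2)

    # 640x480 frame: 40x30 patches -> 1200 patches -> 300 vision tokens
    print(num_vision_tokens(height=480, width=640))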
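
_calculate_timestamps pads the sampled frame indices to a multiple of merge_size, converts them to seconds with the video fps, and averages each merge_size-sized window. A self-contained re-implementation of that logic, with made-up sample values:

    def calculate_timestamps(indices, video_fps, merge_size):
        indices = list(indices)
        if len(indices) % merge_size != 0:
            # pad with the last index so the final window is full
            indices += [indices[-1]] * (merge_size - len(indices) % merge_size)
        seconds = [idx / video_fps for idx in indices]
        return [(seconds[i] + seconds[i + merge_size - 1]) / 2
                for i in range(0, len(seconds), merge_size)]

    # five frames sampled from a 2 fps video, merge_size 2 -> three windows
    print(calculate_timestamps([0, 2, 4, 6, 8], video_fps=2.0, merge_size=2))
    # [0.5, 2.5, 4.0]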
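
get_video_replacement_qwen3vl builds the video placeholder frame by frame: the encoded "<T seconds>" text, then a vision_start token, the frame's video tokens, and a vision_end token. The layout can be sketched without a tokenizer; all token ids below are dummies:

    def build_video_placeholder(frame_time_tokens, tokens_per_frame,
                                vision_start_id, video_token_id,
                                vision_end_id):
        placeholder = []
        for time_tokens in frame_time_tokens:
            placeholder.extend(time_tokens)
            placeholder.extend([vision_start_id]
                               + [video_token_id] * tokens_per_frame
                               + [vision_end_id])
        return placeholder

    # two frames, four video tokens per frame, dummy ids
    print(build_video_placeholder([[101, 102], [103, 104]],
                                  tokens_per_frame=4, vision_start_id=1,
                                  video_token_id=2, vision_end_id=3))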
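
The corner weights computed in fast_pos_embed_interpolate are ordinary bilinear coefficients: for any fractional offsets they are non-negative and sum to one, so each interpolated positional embedding is a convex combination of its four neighbouring grid embeddings. A quick numpy check of that algebra:

    import numpy as np

    dh = np.random.rand(4, 5)    # fractional offset along height
    dw = np.random.rand(4, 5)    # fractional offset along width

    w11 = dh * dw                # weight of (h_ceil,  w_ceil)
    w10 = dh - w11               # (h_ceil,  w_floor) == dh * (1 - dw)
    w01 = dw - w11               # (h_floor, w_ceil)  == dw * (1 - dh)
    w00 = 1 - dh - dw + w11      # (h_floor, w_floor) == (1 - dh) * (1 - dw)

    assert np.allclose(w00 + w01 + w10 + w11, 1.0)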
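
_patch_processing_module in transformers_patch.py relies on rebinding entries in sys.modules, so that later imports of a transformers processing module resolve to the mindone counterpart. The mechanism itself is plain Python; the sketch below uses a throwaway module name instead of real transformers/mindone modules:

    import sys
    import types

    # stand-in for a mindone processing module (illustrative name only)
    substitute = types.ModuleType("fake_processing")
    substitute.BACKEND = "mindone"

    sys.modules["fake_processing"] = substitute

    import fake_processing  # resolved from sys.modules, no file lookup
    print(fake_processing.BACKEND)  # "mindone"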
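
LazyDict, added to vllm_mindspore/utils.py, defers construction to first access and caches the result, which keeps registries of expensive factories cheap to import. A usage sketch, assuming the plugin is installed so the import resolves:

    from vllm_mindspore.utils import LazyDict

    calls = []
    registry = LazyDict({
        "gelu_tanh": lambda: calls.append("built") or "gelu-instance",
    })

    assert calls == []              # nothing constructed at definition time
    first = registry["gelu_tanh"]   # factory runs on first access
    second = registry["gelu_tanh"]  # served from the cache afterwards
    assert first == second and calls == ["built"]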
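
The patched _decode_tensor rebuilds tensors from raw bytes by viewing a uint8 buffer as the target dtype and reshaping, before wrapping the array with ms.from_numpy. The numpy half of that round trip (the mindspore wrap is omitted so the sketch has no framework dependency):

    import numpy as np

    original = np.arange(6, dtype=np.float32).reshape(2, 3)
    payload = original.tobytes()                  # wire representation

    buffer = bytearray(payload)                   # decouple from the message
    arr = np.frombuffer(buffer, dtype=np.uint8)   # raw bytes as uint8
    restored = arr.view(dtype=np.float32).reshape(2, 3)

    assert np.array_equal(original, restored)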