From 80121f6dbabdc6c5d3ef5a07adc5b068c8920c05 Mon Sep 17 00:00:00 2001 From: WanYidong Date: Sat, 20 Sep 2025 11:45:24 +0800 Subject: [PATCH] bugfix: mindspore does not support thread parallelism --- vllm_mindspore/__init__.py | 12 ++-- vllm_mindspore/multimodal/inputs.py | 95 ++++++++++++++++++++++++++--- vllm_mindspore/v1/serial_utils.py | 6 +- 3 files changed, 97 insertions(+), 16 deletions(-) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index aa77b7e3..8af01398 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -345,16 +345,20 @@ from vllm.inputs.registry import InputProcessingContext InputProcessingContext.call_hf_processor = call_hf_processor -from vllm_mindspore.multimodal.inputs import as_kwargs, \ - from_items, MultiModalFieldElem, build_elems +from vllm_mindspore.multimodal.inputs import (as_kwargs, batched_reduce_data, + flat_reduce_data, from_items, + MultiModalFieldElem, _try_stack) -from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.inputs import MultiModalBatchedField from vllm.multimodal.inputs import MultiModalFlatField +from vllm.multimodal.inputs import MultiModalKwargs +MultiModalBatchedField._reduce_data = batched_reduce_data +MultiModalFlatField._reduce_data = flat_reduce_data MultiModalKwargs.as_kwargs = as_kwargs MultiModalKwargs.from_items = from_items +MultiModalKwargs._try_stack = _try_stack -MultiModalFlatField.build_elems = build_elems vllm.multimodal.inputs.MultiModalFieldElem = MultiModalFieldElem vllm.v1.serial_utils.MultiModalFieldElem = MultiModalFieldElem diff --git a/vllm_mindspore/multimodal/inputs.py b/vllm_mindspore/multimodal/inputs.py index 116665b0..95566d49 100644 --- a/vllm_mindspore/multimodal/inputs.py +++ b/vllm_mindspore/multimodal/inputs.py @@ -19,14 +19,15 @@ # limitations under the License. 
"""Adaption for mindspore.""" from collections import defaultdict -from collections.abc import Sequence from dataclasses import dataclass from typing import Union, cast import mindspore +import numpy as np +import torch from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import (BaseMultiModalField, BatchedTensorInputs, - JSONTree, json_map_leaves, + JSONTree, is_list_of, json_map_leaves, nested_tensors_equal) NestedTensors = Union[ @@ -110,11 +111,85 @@ def from_items(items): return MultiModalKwargs(data, items=items) -def build_elems( - self, - modality: str, - key: str, - data: NestedTensors, -) -> Sequence[MultiModalFieldElem]: - field_factory = self._field_factory(modality=modality, key=key) - return [field_factory(data[cast(slice, s)]) for s in self.slices] +def batched_reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + # NOTE: vLLM-MindSpore Plugin: + # Currently mindspore does not support operating tensors in a + # multi-threaded environment, so convert tensors to numpy. + if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): + if len(batch) == 1: + # An optimization when `batch` contains only one tensor: + # - produce exactly same result as `torch.stack(batch)` + # - will achieve zero-copy if the tensor is contiguous + return mindspore.from_numpy(np.expand_dims(batch[0].numpy(), 0)) + first_shape = batch[0].shape + if all(elem.shape == first_shape for elem in batch): + return mindspore.from_numpy(np.stack([b.numpy() for b in batch])) + + return batch + + +def flat_reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + # NOTE: vLLM-MindSpore Plugin: + # Currently mindspore does not support operating tensors in a + # multi-threaded environment, so convert tensors to numpy. 
+ if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): + if len(batch) == 1: + # An optimization when `batch` contains only one tensor: + # - produce exactly same result as `torch.concat(batch)` + # - will achieve zero-copy if the tensor is contiguous + return mindspore.from_numpy(batch[0].numpy()) + + def _expect_same_shape(tensor: torch.Tensor): + return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:] + + first_shape = _expect_same_shape(batch[0]) + + if all(_expect_same_shape(elem) == first_shape for elem in batch): + return mindspore.from_numpy( + np.concatenate([b.numpy() for b in batch], axis=self.dim)) + + assert self.dim == 0, "dim == 0 is required for nested list" + return [e for elem in batch for e in elem] + + +@staticmethod +def _try_stack(nested_tensors: NestedTensors, + pin_memory: bool = False) -> NestedTensors: + """ + Stack the inner dimensions that have the same shape in + a nested list of tensors. + + Thus, a dimension represented by a list means that the inner + dimensions are different for each element along that dimension. + """ + if isinstance(nested_tensors, torch.Tensor): + return nested_tensors + + # TODO: Remove these once all models have been migrated + if isinstance(nested_tensors, np.ndarray): + return torch.from_numpy(nested_tensors) + if isinstance(nested_tensors, (int, float)): + return torch.tensor(nested_tensors) + + stacked = [ + MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors + ] + if not is_list_of(stacked, torch.Tensor, check="all"): + # Only tensors (not lists) can be stacked. + return stacked + + # NOTE: vLLM-MindSpore Plugin: + # Currently mindspore does not support operating tensors in a + # multi-threaded environment, so convert tensors to numpy. 
+ tensors_ = cast(list[torch.Tensor], stacked) + if len(tensors_) == 1: + # An optimization when `tensors_` contains only one tensor: + # - produce exactly same result as `torch.stack(tensors_)` + # - will achieve zero-copy if the tensor is contiguous + return mindspore.from_numpy(np.expand_dims(tensors_[0].numpy(), 0)) + + if any(t.shape != tensors_[0].shape for t in tensors_): + # The tensors have incompatible shapes and can't be stacked. + return tensors_ + + return mindspore.from_numpy(np.stack([t.numpy() for t in tensors_])) diff --git a/vllm_mindspore/v1/serial_utils.py b/vllm_mindspore/v1/serial_utils.py index 45204af6..07b49df7 100644 --- a/vllm_mindspore/v1/serial_utils.py +++ b/vllm_mindspore/v1/serial_utils.py @@ -46,8 +46,10 @@ def _encode_tensor( ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None # view the tensor as a contiguous 1D array of bytes - arr = obj.flatten().contiguous().numpy() - arr = arr.view(dtype=np.uint8) + # NOTE: vLLM-MindSpore Plugin: + # Currently mindspore does not support operating tensors in a + # multi-threaded environment, so convert tensors to numpy. + arr = obj.numpy().flatten().view(dtype=np.uint8) if obj.nbytes < self.size_threshold: # Smaller tensors are encoded inline, just like ndarrays. CUSTOM_TYPE_RAW_VIEW = 3 -- Gitee