From d06e52a36933ca66436088ebfe3a17f8ec9611c4 Mon Sep 17 00:00:00 2001 From: one_east Date: Fri, 7 Nov 2025 11:03:29 +0800 Subject: [PATCH] bugfix: np.concatenat error --- tests/st/python/test_multi_modal_inputs.py | 81 ++++++++++++++++++++++ vllm_mindspore/multimodal/inputs.py | 2 +- 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 tests/st/python/test_multi_modal_inputs.py diff --git a/tests/st/python/test_multi_modal_inputs.py b/tests/st/python/test_multi_modal_inputs.py new file mode 100644 index 000000000..b0ac24ee8 --- /dev/null +++ b/tests/st/python/test_multi_modal_inputs.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test multi-modal inputs function.""" + +# isort: skip_file + +import vllm_mindspore + +from dataclasses import dataclass +from typing import Union +import mindspore +import torch +import pytest +import numpy as np +from vllm.multimodal.inputs import MultiModalFlatField, NestedTensors + + +@dataclass(frozen=True) +class MockMultiModalFlatField(MultiModalFlatField): + slices: Union[list[slice], list[list[slice]]] + dim: int = 0 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_reduce_data_with_torch_tensors(): + batch = [torch.randn(2, 3, 4), torch.randn(2, 3, 4), torch.randn(2, 3, 4)] + field = MockMultiModalFlatField(slices=[slice(None)], dim=1) + result = field._reduce_data(batch) + assert isinstance(result, mindspore.Tensor) + assert result.shape == (2, 9, 4) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_reduce_data_with_single_tensor(): + batch = [torch.randn(2, 3, 4)] + field = MockMultiModalFlatField(slices=[slice(None)], dim=2) + result = field._reduce_data(batch) + assert isinstance(result, mindspore.Tensor) + assert result.shape == (2, 3, 4) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_reduce_data_with_nested_tensors(): + first_tensor = torch.randn(3, 5) + second_tensor = torch.randn(3, 5) + batch = (first_tensor, second_tensor) + field = MockMultiModalFlatField(slices=[slice(None)], dim=0) + result = field._reduce_data(batch) + assert len(result) == 6 + assert result[0].size(0) == 5 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_reduce_data_with_mismatched_dim(): + first_tensor = torch.randn(3, 5) + second_tensor = torch.randn(3, 5) + batch = (first_tensor, second_tensor) + field = MockMultiModalFlatField(slices=[slice(None)], dim=1) + with pytest.raises(AssertionError): + field._reduce_data(batch) diff --git a/vllm_mindspore/multimodal/inputs.py b/vllm_mindspore/multimodal/inputs.py index 7b7436ab6..bf3509516 100644 --- a/vllm_mindspore/multimodal/inputs.py +++ b/vllm_mindspore/multimodal/inputs.py @@ -157,7 +157,7 @@ def flat_reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: if all(_expect_same_shape(elem) == first_shape for elem in batch): return mindspore.from_numpy( - np.concatenat([b.numpy() for b in batch], axis=self.dim)) + np.concatenate([b.numpy() for b in batch], axis=self.dim)) assert self.dim == 0, "dim == 0 is required for nested list" return [e for elem in batch for e in elem] -- Gitee