diff --git a/tests/models/ sp_transform/.keep b/tests/models/ sp_transform/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/models/ sp_transform/test_ sp_transform.py b/tests/models/ sp_transform/test_ sp_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..de5cc6538be17390152fc98580152a9de3351177 --- /dev/null +++ b/tests/models/ sp_transform/test_ sp_transform.py @@ -0,0 +1,195 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""spectral transform testcase""" + +import os +import sys +import pytest +import numpy as np + +import mindspore as ms +from mindspore import Tensor, set_seed +from mindspore import dtype as mstype + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(PROJECT_ROOT) + +# pylint: disable=wrong-import-position +from mindflow.cell.neural_operators.sp_transform import ( + ConvCell, TransformCell, TransformCell1D, TransformCell2D, TransformCell3D, Dim +) +# pylint: enable=wrong-import-position + +set_seed(123456) + + +# ====================================================================== +# ConvCell Tests +# ====================================================================== +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_convcell_forward(mode): + """ + Feature: Test ConvCell in 1D/2D/3D. + Description: Check output shape for conv1d, conv2d, conv3d. + Expectation: Output shape matches input shape. + """ + ms.set_context(mode=mode) + + # 1D conv + x1 = Tensor(np.random.randn(2, 3, 16), mstype.float32) + net1 = ConvCell(dim=1, in_channels=3, out_channels=4, kernel_size=3) + y1 = net1(x1) + assert y1.shape == (2, 4, 16) + + # 2D conv + x2 = Tensor(np.random.randn(2, 3, 16, 16), mstype.float32) + net2 = ConvCell(dim=2, in_channels=3, out_channels=5, kernel_size=3) + y2 = net2(x2) + assert y2.shape == (2, 5, 16, 16) + + # 3D conv + x3 = Tensor(np.random.randn(2, 3, 8, 8, 8), mstype.float32) + net3 = ConvCell(dim=3, in_channels=3, out_channels=6, kernel_size=3) + y3 = net3(x3) + assert y3.shape == (2, 6, 8, 8, 8) + + +# ====================================================================== +# TransformCell1D +# ====================================================================== +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_transformcell1d(mode): + """ + Feature: Test TransformCell1D. + Description: Polynomial transform in 1D. + Expectation: matmul result has correct shape. + """ + ms.set_context(mode=mode) + + res = 16 + modes = 5 + batch, ch = 2, 3 + + transform = Tensor(np.random.randn(modes, res), mstype.float32) + net = TransformCell1D(transform) + + x = Tensor(np.random.randn(batch, ch, res), mstype.float32) + y = net(x) + + assert y.shape == (batch, ch, modes) + + +# ====================================================================== +# TransformCell2D +# ====================================================================== +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize("axis", [Dim.x, Dim.y]) +@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_transformcell2d(mode, axis): + """ + Feature: Test TransformCell2D. + Description: Check 2D polynomial transform on both axes. + Expectation: Output shapes match transform rules. + """ + ms.set_context(mode=mode) + + res = 12 + modes = 4 + batch, ch = 2, 3 + + transform = Tensor(np.random.randn(modes, res), mstype.float32) + net = TransformCell2D(transform, axis) + + x = Tensor(np.random.randn(batch, ch, res, res), mstype.float32) + y = net(x) + + # For both axes, last transformed dimension becomes modes + assert y.shape in ( + (batch, modes, ch, res), # axis x + (batch, res, modes, ch) # axis y + ) + + +# ====================================================================== +# TransformCell3D +# ====================================================================== +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize("axis", [Dim.x, Dim.y, Dim.z]) +@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_transformcell3d(mode, axis): + """ + Feature: Test TransformCell3D forward. + Description: Ensure shape correctness for 3D polynomial transform. + Expectation: One dimension is replaced with modes. + """ + ms.set_context(mode=mode) + + res = 10 + modes = 6 + batch, ch = 1, 2 + + transform = Tensor(np.random.randn(modes, res), mstype.float32) + net = TransformCell3D(transform, axis) + + x = Tensor(np.random.randn(batch, ch, res, res, res), mstype.float32) + y = net(x) + + # validate one dimension equals modes + assert modes in y.shape + + +# ====================================================================== +# TransformCell (Unified Interface) +# ====================================================================== +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize("dim", [1, 2, 3]) +@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_transformcell_unified(dim, mode): + """ + Feature: Test TransformCell wrapper dispatcher. + Description: Validate correct dim selection. + Expectation: Output has modes in one axis. + """ + ms.set_context(mode=mode) + + res = 10 + modes = 4 + batch, ch = 2, 3 + transform = Tensor(np.random.randn(modes, res), mstype.float32) + + net = TransformCell(dim=dim, transform=transform, axis=Dim.x) + + if dim == 1: + x = Tensor(np.random.randn(batch, ch, res), mstype.float32) + elif dim == 2: + x = Tensor(np.random.randn(batch, ch, res, res), mstype.float32) + else: + x = Tensor(np.random.randn(batch, ch, res, res, res), mstype.float32) + + y = net(x) + + assert modes in y.shape