diff --git a/mindscience/models/neural_operator/pdenet.py b/mindscience/models/neural_operator/pdenet.py index 3c0550af1d0059eb65f18a025a2bacc9952a0a43..b5041879a9b3b5c900d2c2f38cbbfcf930563028 100644 --- a/mindscience/models/neural_operator/pdenet.py +++ b/mindscience/models/neural_operator/pdenet.py @@ -75,6 +75,70 @@ class PDENet(nn.Cell): ``Ascend`` ``GPU`` Examples: + import numpy as np + import pytest + from mindspore import Tensor + import mindspore.common.dtype as mstype + from mindscience.models.neural_operator.pdenet import PDENet + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend_training + @pytest.mark.env_onecard + def test_pdenet_forward(): + height, width, channels = 8, 8, 2 + kernel_size, max_order = 3, 2 + + # Construct network + net = PDENet( + height=height, + width=width, + channels=channels, + kernel_size=kernel_size, + max_order=max_order, + dx=0.01, + dy=0.01, + dt=0.01, + periodic=True, + enable_moment=True, + if_fronzen=False + ) + + # Random input + np_input = np.random.rand(1, channels, height, width).astype(np.float32) + x = Tensor(np_input, mstype.float32) + + # Forward pass + out = net(x) + + # Shape check + assert out.shape == (1, channels, height, width) + + # Type check + assert out.dtype == mstype.float32 + + # Ensure numerical values are finite + assert np.isfinite(out.asnumpy()).all() + + def test_pdenet_no_moment(): + net = PDENet( + height=4, + width=4, + channels=1, + kernel_size=3, + max_order=1, + enable_moment=False + ) + + x = Tensor(np.random.rand(1, 1, 4, 4).astype(np.float32)) + y = net(x) + + assert y.shape == (1, 1, 4, 4) + assert np.isfinite(y.asnumpy()).all() + + def test_coefficient_parameter(): + net = PDENet(4, 4, 1, 3, 2) + coe = net.coe + assert coe.shape[0] == net.num_filter - 1 + assert coe.ndim == 3 >>> import numpy as np >>> from mindspore import Tensor >>> import mindspore.common.dtype as mstype @@ -143,6 +207,7 @@ class PDENet(nn.Cell): return self.coe_param def _one_step_forward(self, x): + """Perform one forward step for PDE evolution.""" if self.periodic: x = self._periodicpad(x) @@ -180,6 +245,7 @@ class PDENet(nn.Cell): return out def _init_moment(self): + """Initialize the moment parameters for PDE-Net.""" raw_moment = ms_np.zeros((self.num_filter, self.kernel_size, self.kernel_size)) mask = ms_np.ones((self.num_filter, self.kernel_size, self.kernel_size)) scale = ms_np.ones((self.num_filter,)) @@ -206,6 +272,7 @@ class PDENet(nn.Cell): self.moment = Parameter(raw_moment) def _periodicpad(self, x): + """Apply periodic padding to the input tensor.""" cast = ops.Cast() x = cast(x, self.dtype) x_dim = len(x.shape)