From 833b547eb366952f1a2dd4c61145f1d5a317d266 Mon Sep 17 00:00:00 2001 From: MNxyz Date: Fri, 7 Nov 2025 09:19:47 +0000 Subject: [PATCH 1/2] update mindscience/models/neural_operator/vit_kno.py. Signed-off-by: MNxyz --- mindscience/models/neural_operator/vit_kno.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/mindscience/models/neural_operator/vit_kno.py b/mindscience/models/neural_operator/vit_kno.py index 48039d3e2..b72d0f957 100644 --- a/mindscience/models/neural_operator/vit_kno.py +++ b/mindscience/models/neural_operator/vit_kno.py @@ -68,6 +68,39 @@ class ViTKNO(nn.Cell): >>> print(output.shape) (16, 128, 5120) + # 下面是新增的补充示例,可直接运行用于 smoke-test: + >>> # Minimal runnable example (supplementary) + >>> import numpy as _np + >>> from mindspore import Tensor as _Tensor, dtype as _dtype + >>> from mindspore.common.initializer import Normal as _Normal + >>> from mindspore.common.initializer import initializer as _initializer + >>> # Note: if your environment lacks MindSpore runtime, this block is illustrative. + >>> B, C, H, W = 2, 1, 64, 64 + >>> data = _np.random.randn(B, C, H, W).astype(_np.float32) + >>> tensor_in = _Tensor(data, dtype=_dtype.float32) + >>> model = ViTKNO( + ... image_size=(H, W), + ... patch_size=8, + ... in_channels=C, + ... out_channels=C, + ... encoder_embed_dims=64, + ... encoder_depths=2, + ... mlp_ratio=2, + ... dropout_rate=1.0, + ... drop_path_rate=0.0, + ... num_blocks=4, + ... settings="MLP", + ... high_freq=True, + ... encoder_network=False, + ... compute_dtype=_dtype.float32 + ... ) + >>> out, recons = model(tensor_in) + >>> print("out shape:", out.shape) + >>> print("recons shape:", recons.shape) + # Expected shapes (batch, channels, H, W) for both out and recons, e.g.: + # out shape: (2, 1, 64, 64) + # recons shape: (2, 1, 64, 64) + """ def __init__( -- Gitee From 47e61976c35ee6d1a5f46f50d20b92f7cdd2b620 Mon Sep 17 00:00:00 2001 From: MNxyz Date: Fri, 7 Nov 2025 10:03:37 +0000 Subject: [PATCH 2/2] update mindscience/models/neural_operator/vit_kno.py. Signed-off-by: MNxyz --- mindscience/models/neural_operator/vit_kno.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindscience/models/neural_operator/vit_kno.py b/mindscience/models/neural_operator/vit_kno.py index b72d0f957..15bd48481 100644 --- a/mindscience/models/neural_operator/vit_kno.py +++ b/mindscience/models/neural_operator/vit_kno.py @@ -145,8 +145,8 @@ class ViTKNO(nn.Cell): self.pos_embed = Parameter(np.zeros((1, num_patches, encoder_embed_dims))) self.pos_drop = nn.Dropout(dropout_rate) - dpr = [x for x in ops.linspace(Tensor(0, mindspore.float32), Tensor(drop_path_rate, mindspore.float32), - self.encoder_depths)] + dpr = list(ops.linspace(Tensor(0, mindspore.float32), Tensor(drop_path_rate, mindspore.float32), self.encoder_depths)) + self.blocks = nn.CellList([ AFNOBlock(embed_dims=self.encoder_embed_dims, mlp_ratio=mlp_ratio, dropout_rate=dropout_rate, -- Gitee