diff --git a/.jenkins/test/config/flow_config/dependent_packages.yaml b/.jenkins/test/config/flow_config/dependent_packages.yaml
index 9e70df642d476be60aeb9313c8938893f69eee28..574dc9b2fba132814d611cc5f771014c52a28d43 100644
--- a/.jenkins/test/config/flow_config/dependent_packages.yaml
+++ b/.jenkins/test/config/flow_config/dependent_packages.yaml
@@ -1,2 +1,2 @@
 mindspore:
-  '/mindspore/mindspore/version/202503/20250326/master_20250326010019_b91eca2945e61641319f9887aa76a1ccb38604d3_newest/'
\ No newline at end of file
+  '/mindspore/mindspore/version/202506/20250603/master_20250603091707_20e3faa947757ae90617a0300d5dc8d70398ec63_newest/'
diff --git a/MindFlow/mindflow/cell/__init__.py b/MindFlow/mindflow/cell/__init__.py
index d71670098062ee05380aa9b73dc81595d44998e0..69fab7c56a0ce21859e16ec699582ef381c315e5 100644
--- a/MindFlow/mindflow/cell/__init__.py
+++ b/MindFlow/mindflow/cell/__init__.py
@@ -18,6 +18,7 @@ from .basic_block import LinearBlock, ResBlock, InputScale, FCSequential, MultiS
 from .neural_operators import FNO1D, FNO2D, FNO3D, KNO1D, KNO2D, PDENet, PeRCNN, SNO, SNO1D, SNO2D, SNO3D
 from .attention import Attention, MultiHeadAttention, TransformerBlock
 from .vit import ViT
+from .kan import KANLayer, curve2coef
 from .unet2d import UNet2D
 from .sno_utils import poly_data, get_poly_transform, interpolate_1d_dataset, interpolate_2d_dataset
 from .diffusion import DiffusionScheduler, DiffusionTrainer, DDPMScheduler, DDIMScheduler, DDPMPipeline, DDIMPipeline
@@ -26,6 +27,6 @@ from .diffusion_transformer import DiffusionTransformer, ConditionDiffusionTrans
 __all__ = ["get_activation", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "UNet2D", "PeRCNN",
            "SNO", "SNO1D", "SNO2D", "SNO3D", "Attention", "MultiHeadAttention", "TransformerBlock", "ViT",
            "DDPMPipeline", "DDIMPipeline", "DiffusionTrainer", "DiffusionScheduler", "DDPMScheduler",
-           "DDIMScheduler", "DiffusionTransformer", "ConditionDiffusionTransformer"]
+           "DDIMScheduler", "DiffusionTransformer", "ConditionDiffusionTransformer", "KANLayer", "curve2coef"]
 __all__.extend(basic_block.__all__)
 __all__.extend(sno_utils.__all__)
diff --git a/MindFlow/mindflow/cell/kan.py b/MindFlow/mindflow/cell/kan.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f5021a1e32006815557502a3a8f672e8ca3aee1
--- /dev/null
+++ b/MindFlow/mindflow/cell/kan.py
@@ -0,0 +1,485 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""KAN api"""
+from mindspore import Tensor, Parameter, ops, nn, mint
+from mindspore.scipy.linalg import lstsq
+from mindspore.common import dtype as mstype
+import numpy as np
+# pylint: disable=C0103
+# pylint: disable=W0102
+
+
+def sparse_mask(in_dim, out_dim):
+    '''
+    get sparse mask
+    '''
+    in_coord = ops.arange(in_dim) * 1/in_dim + 1/(2*in_dim)
+    out_coord = ops.arange(out_dim) * 1/out_dim + 1/(2*out_dim)
+
+    dist_mat = ops.abs(out_coord[:, None] - in_coord[None, :])
+    in_nearest = ops.argmin(dist_mat, axis=0)
+    in_connection = ops.stack([ops.arange(in_dim), in_nearest]).T
+    out_nearest = ops.argmin(dist_mat, axis=1)
+    out_connection = ops.stack([out_nearest, ops.arange(out_dim)]).T
+    all_connection = ops.cat([in_connection, out_connection], axis=0)
+    mask = ops.zeros((in_dim, out_dim), mstype.float32)
+    mask[all_connection[:, 0], all_connection[:, 1]] = 1.0
+
+    return mask
+
+
+def B_batch(x, grid, k=0):
+    '''
+    evaluate x on B-spline bases
+
+    Args:
+    -----
+        x : 2D Tensor
+            inputs, shape (number of samples, number of splines)
+        grid : 2D Tensor
+            grids, shape (number of splines, number of grid points)
+        k : int
+            the piecewise polynomial order of splines.
+
+    Returns:
+    --------
+        spline values : 3D Tensor
+            shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.
+
+    Example
+    -------
+    >>> from mindflow.cell.kan import B_batch
+    >>> x = ops.rand(100, 2)
+    >>> grid = ops.linspace(-1, 1, steps=11)[None, :].broadcast_to((2, 11))
+    >>> B_batch(x, grid, k=3).shape
+    '''
+
+    x = x.expand_dims(axis=2)
+    grid = grid.expand_dims(axis=0)
+
+    if k == 0:
+        value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
+    else:
+        B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1)
+        term1 = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] -
+                                               grid[:, :, :-(k + 1)] + 1e-8) * B_km1[:, :, :-1]
+        term2 = (grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] -
+                                            grid[:, :, 1:(-k)] + 1e-8) * B_km1[:, :, 1:]
+        value = term1 + term2
+
+    return ops.nan_to_num(value)
+
+
+def coef2curve(x_eval, grid, coef, k):
+    '''
+    converting B-spline coefficients to B-spline curves.
+    Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).
+
+    Args:
+    -----
+        x_eval : 2D Tensor
+            shape (batch, in_dim)
+        grid : 2D Tensor
+            shape (in_dim, G+2k+1). G: the number of grid intervals; k: spline order.
+        coef : 3D Tensor
+            shape (in_dim, out_dim, G+k)
+        k : int
+            the piecewise polynomial order of splines.
+
+    Returns:
+    --------
+        y_eval : 3D Tensor
+            shape (batch, in_dim, out_dim)
+
+    '''
+    b_splines = B_batch(x_eval, grid, k=k)
+    y_eval = mint.einsum('ijk,jlk->ijl', b_splines, coef)
+    return y_eval
+
+
+def curve2coef(x_eval, y_eval, grid, k):
+    '''
+    converting B-spline curves to B-spline coefficients using least squares.
+
+    Args:
+    -----
+        x_eval : 2D Tensor
+            shape (batch, in_dim)
+        y_eval : 3D Tensor
+            shape (batch, in_dim, out_dim)
+        grid : 2D Tensor
+            shape (in_dim, G+2k+1). G: the number of grid intervals; k: spline order.
+        k : int
+            spline order
+
+    Returns:
+    --------
+        coef : 3D Tensor
+            shape (in_dim, out_dim, G+k)
+    '''
+    batch = x_eval.shape[0]
+    in_dim = x_eval.shape[1]
+    out_dim = y_eval.shape[2]
+    n_coef = grid.shape[1] - k - 1
+
+    mat = B_batch(x_eval, grid, k)
+    mat = mat.transpose(1, 0, 2).expand_dims(
+        axis=1).broadcast_to((in_dim, out_dim, batch, n_coef))
+    y_eval = y_eval.transpose(1, 2, 0).expand_dims(axis=3)
+    coef = lstsq(mat, y_eval)[0][:, :, :, 0]
+    return coef
+
+
+def extend_grid(grid, k_extend=0):
+    '''
+    extend grid
+    '''
+    h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
+
+    for _ in range(k_extend):
+        grid = ops.cat([grid[:, [0]] - h, grid], axis=1)
+        grid = ops.cat([grid, grid[:, [-1]] + h], axis=1)
+
+    return grid
+
+
+class KANLayer(nn.Cell):
+    """
+    KANLayer class
+
+    Attributes:
+    -----------
+        in_dim: int
+            input dimension
+        out_dim: int
+            output dimension
+        num: int
+            the number of grid intervals
+        k: int
+            the piecewise polynomial order of splines
+        noise_scale: float
+            the scale of the noise injected into the splines at initialization
+        coef: 3D Tensor
+            coefficients of B-spline bases, shape (in_dim, out_dim, G+k)
+        scale_base_mu: float
+            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = scale_base_mu
+        scale_base_sigma: float
+            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), sigma = scale_base_sigma
+        scale_sp: float
+            magnitude of the spline function spline(x)
+        base_fun: function
+            residual function b(x)
+        mask: 2D Tensor
+            mask of spline functions, shape (in_dim, out_dim).
+            setting some element of the mask to zero means setting the corresponding activation to zero function.
+        grid_eps: float in [0, 1]
+            a hyperparameter used in update_grid_from_samples.
+            When grid_eps = 1, the grid is uniform; when grid_eps = 0,
+            the grid is partitioned using percentiles of samples.
+            0 < grid_eps < 1 interpolates between the two extremes.
+
+    Args:
+    -----
+        x : 2D Tensor
+            inputs, shape (number of samples, input dimension)
+
+    Returns:
+    --------
+        y : 2D Tensor
+            outputs, shape (number of samples, output dimension)
+        preacts : 3D Tensor
+            fan out x into activations, shape (number of samples, output dimension, input dimension)
+        postacts : 3D Tensor
+            the outputs of activation functions with preacts as inputs
+        postspline : 3D Tensor
+            the outputs of spline functions with preacts as inputs
+
+    Example
+    -------
+    >>> from mindflow.cell import KANLayer
+    >>> model = KANLayer(in_dim=3, out_dim=5)
+    >>> x = ops.randn(100, 3)
+    >>> y, preacts, postacts, postspline = model(x)
+    >>> y.shape, preacts.shape, postacts.shape, postspline.shape
+    """
+
+    def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0,
+                 scale_base_sigma=1.0, scale_sp=1.0, base_fun=nn.SiLU(), grid_eps=0.02,
+                 grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, sparse_init=False):
+        super().__init__()
+        self.out_dim = out_dim
+        self.in_dim = in_dim
+        self.num = num
+        self.k = k
+        self.grid_eps = grid_eps
+        self.base_fun = base_fun
+
+        # Initialize grid
+        grid_init = np.linspace(grid_range[0], grid_range[1], num=num+1)
+        grid_init = np.tile(grid_init, (in_dim, 1))
+        grid = Tensor(grid_init, mstype.float32)
+        grid = extend_grid(grid, k_extend=k)
+        self.grid = Parameter(grid, requires_grad=False, name="grid")
+
+        # Initialize coefficients with noise
+        noises = (np.random.rand(num+1, in_dim, out_dim) - 0.5) * \
+            noise_scale / num
+        noises = Tensor(noises, mstype.float32)
+
+        # Use fixed grid for initial coef calculation
+        static_grid = self.grid[:, k:-k].transpose(1, 0)
+        self.coef = Parameter(
+            curve2coef(static_grid, noises, self.grid, k),
+            name="coef"
+        )
+
+        # Initialize mask
+        if sparse_init:
+            mask_val = sparse_mask(in_dim, out_dim)
+        else:
+            mask_val = ops.ones((in_dim, out_dim), mstype.float32)
+        self.mask = Parameter(mask_val, requires_grad=False, name="mask")
+
+        # Initialize scales
+        scale_base_init = scale_base_mu / np.sqrt(in_dim) + \
+            scale_base_sigma * (np.random.rand(in_dim, out_dim)
+                                * 2 - 1) / np.sqrt(in_dim)
+        self.scale_base = Parameter(
+            Tensor(scale_base_init, mstype.float32),
+            requires_grad=sb_trainable,
+            name="scale_base"
+        )
+
+        scale_sp_init = np.ones((in_dim, out_dim)) * \
+            scale_sp / np.sqrt(in_dim) * mask_val.asnumpy()
+        self.scale_sp = Parameter(
+            Tensor(scale_sp_init, mstype.float32),
+            requires_grad=sp_trainable,
+            name="scale_sp"
+        )
+
+    def construct(self, x):
+        """construct"""
+        batch = x.shape[0]
+        preacts = x.expand_dims(axis=1).broadcast_to(
+            (batch, self.out_dim, self.in_dim))
+
+        base = self.base_fun(x)
+        y = coef2curve(x, self.grid, self.coef, self.k)
+        postspline = y.transpose(0, 2, 1)
+
+        # Combine base and spline
+        scale_base_term = self.scale_base.expand_dims(
+            axis=0) * base.expand_dims(axis=2)
+        scale_sp_term = self.scale_sp.expand_dims(axis=0) * y
+        y = scale_base_term + scale_sp_term
+        y = self.mask.expand_dims(axis=0) * y
+
+        postacts = y.transpose(0, 2, 1)
+        y = ops.sum(y, dim=1)  # Sum over input dimension
+
+        return y, preacts, postacts, postspline
+
+    def update_grid_from_samples(self, x, mode='sample'):
+        '''
+        update grid from samples
+
+        Args:
+        -----
+            x : 2D Tensor
+                inputs, shape (number of samples, input dimension)
+
+        Returns:
+        --------
+        None
+
+        Example
+        -------
+        >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
+        >>> print(model.grid.data)
+        >>> x = ops.linspace(-3, 3, steps=100)[:, None]
+        >>> model.update_grid_from_samples(x)
+        >>> print(model.grid.data)
+        '''
+        x_pos, _ = ops.sort(x, axis=0)
+        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
+        num_interval = self.grid.shape[1] - 1 - 2 * self.k
+
+        def get_grid(num_interval):
+            ids = [int(x.shape[0] / num_interval * i) for i in range(num_interval)] + [-1]
+            grid_adaptive = x_pos[ids].transpose(1, 0)
+            margin = 0.00
+            h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_interval
+            grid_uniform = grid_adaptive[:, [0]] - margin + \
+                h * ops.arange(num_interval + 1).expand_dims(0)
+            grid = self.grid_eps * grid_uniform + \
+                (1 - self.grid_eps) * grid_adaptive
+            return grid
+
+        grid = get_grid(num_interval)
+
+        if mode == 'grid':
+            sample_grid = get_grid(2 * num_interval)
+            x_pos = sample_grid.transpose()
+            y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
+
+        self.grid.set_data(extend_grid(grid, k_extend=self.k))
+        self.coef.set_data(curve2coef(x_pos, y_eval, self.grid, self.k))
+
+    def initialize_grid_from_parent(self, parent, x, mode='sample'):
+        '''
+        Initialize grid from a parent KANLayer & samples
+
+        Args:
+        -----
+            parent : KANLayer
+                a parent KANLayer (whose grid is usually coarser than the current model)
+            x : 2D Tensor
+                inputs, shape (number of samples, input dimension)
+
+        Returns:
+        --------
+        None
+
+        Example
+        -------
+        >>> batch = 100
+        >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
+        >>> print(parent_model.grid.data)
+        >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
+        >>> x = ops.randn((batch, 1))
+        >>> model.initialize_grid_from_parent(parent_model, x)
+        >>> print(model.grid.data)
+        '''
+        x_pos = ops.sort(x, axis=0)[0]
+        y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
+
+        def get_grid():
+            # Create a temporary KANLayer for interpolation
+            sp2 = KANLayer(
+                in_dim=1,
+                out_dim=self.in_dim,
+                k=1,
+                num=parent.grid.shape[1] - 2 * parent.k -
+                1,  # G = num_grid_points - 2k - 1
+                scale_base_mu=0.0,
+                scale_base_sigma=0.0
+            )
+
+            # Get parent grid points
+            x_pos = parent.grid[:, parent.k:-parent.k]
+
+            # Prepare input for curve2coef
+            static_grid = sp2.grid[:, sp2.k:-sp2.k].transpose(1, 0)
+            static_grid = static_grid.broadcast_to(
+                (static_grid.shape[0], self.in_dim))
+
+            # Compute coefficients for interpolation
+            sp2_coef = curve2coef(
+                static_grid,
+                x_pos.transpose(1, 0).expand_dims(axis=2),
+                sp2.grid,
+                k=1
+            ).transpose(1, 0, 2)
+
+            sp2.coef.set_data(sp2_coef)
+
+            # Generate new grid points using interpolation
+            percentile = Tensor(np.linspace(-1, 1, self.num + 1),
+                                mstype.float32).expand_dims(axis=1)
+            grid_new = sp2(percentile)[0].transpose(1, 0)
+            return grid_new
+
+        if mode == 'grid':
+            sample_grid = get_grid()
+            x_pos = sample_grid.permute(1, 0)
+            y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
+
+        # Set current grid and coefficients
+        grid = get_grid()
+        grid_extended = extend_grid(grid, k_extend=self.k)
+        self.grid.set_data(grid_extended)
+        self.coef.set_data(curve2coef(
+            x_pos,
+            y_eval,
+            self.grid,
+            self.k
+        ))
+
+    def get_subset(self, in_id, out_id):
+        '''
+        get a smaller KANLayer from a larger KANLayer (used for pruning)
+
+        Args:
+        -----
+            in_id : list
+                id of selected input neurons
+            out_id : list
+                id of selected output neurons
+
+        Returns:
+        --------
+        spb : KANLayer
+
+        Example
+        -------
+        >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
+        >>> kanlayer_small = kanlayer_large.get_subset([0, 9], [1, 2, 3])
+        >>> kanlayer_small.in_dim, kanlayer_small.out_dim
+        (2, 3)
+        '''
+        spb = KANLayer(len(in_id), len(out_id), self.num,
+                       self.k, base_fun=self.base_fun)
+        spb.grid.set_data(self.grid[in_id])
+        spb.coef.set_data(self.coef[in_id][:, out_id])
+        spb.scale_base.set_data(self.scale_base[in_id][:, out_id])
+        spb.scale_sp.set_data(self.scale_sp[in_id][:, out_id])
+        spb.mask.set_data(self.mask[in_id][:, out_id])
+        return spb
+
+    def swap(self, i1, i2, mode='in'):
+        '''Helper function to swap tensor elements'''
+        def swap_tensor(data, i1, i2, mode):
+            if mode == 'in':
+                # Swap entire rows
+                row1 = data[i1].copy()
+                row2 = data[i2].copy()
+                data[i1] = row2
+                data[i2] = row1
+            elif mode == 'out':
+                # Swap entire columns
+                col1 = data[:, i1].copy()
+                col2 = data[:, i2].copy()
+                data[:, i1] = col2
+                data[:, i2] = col1
+            return data
+
+        # Swap grid if mode is 'in'
+        if mode == 'in':
+            grid_data = self.grid.value()
+            self.grid.set_data(swap_tensor(grid_data, i1, i2, 'in'))
+
+        # Swap other parameters
+        for param in [self.coef, self.scale_base, self.scale_sp, self.mask]:
+            param_data = param.value()
+            param.set_data(swap_tensor(param_data, i1, i2, mode))
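
[Editor's note, not part of the diff] A minimal usage sketch of the new layer for reviewers. The shapes follow the KANLayer defaults and docstring above; the random input is illustrative only.

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindflow.cell import KANLayer

# 100 illustrative samples with 3 features, drawn inside the default grid_range [-1, 1]
layer = KANLayer(in_dim=3, out_dim=5, num=5, k=3)
x = Tensor(np.random.uniform(-1, 1, size=(100, 3)).astype(np.float32), mstype.float32)
y, preacts, postacts, postspline = layer(x)
print(y.shape)           # (100, 5): spline + residual terms summed over the input dimension
print(preacts.shape)     # (100, 5, 3): input fanned out per output neuron
print(postacts.shape)    # (100, 5, 3): per-edge activations after scaling and masking
print(postspline.shape)  # (100, 5, 3): raw spline outputs before the residual term is added
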
diff --git a/tests/st/mindflow/cell/kan/data/curve2coef.npz b/tests/st/mindflow/cell/kan/data/curve2coef.npz
new file mode 100644
index 0000000000000000000000000000000000000000..69d4c3c2710e37db6ec3cd4edbf795a46a65dacd
Binary files /dev/null and b/tests/st/mindflow/cell/kan/data/curve2coef.npz differ
diff --git a/tests/st/mindflow/cell/kan/data/init_output.npz b/tests/st/mindflow/cell/kan/data/init_output.npz
new file mode 100644
index 0000000000000000000000000000000000000000..e4b0decef55cf973199c028d3a87b12decbffa71
Binary files /dev/null and b/tests/st/mindflow/cell/kan/data/init_output.npz differ
diff --git a/tests/st/mindflow/cell/kan/data/init_x.npy b/tests/st/mindflow/cell/kan/data/init_x.npy
new file mode 100644
index 0000000000000000000000000000000000000000..a76d58c18a5d5c40453f1270ece179b9061af6fb
Binary files /dev/null and b/tests/st/mindflow/cell/kan/data/init_x.npy differ
diff --git a/tests/st/mindflow/cell/kan/data/kan.ckpt b/tests/st/mindflow/cell/kan/data/kan.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..580fd170136cc3abd311bd1ca955284197d37f04
Binary files /dev/null and b/tests/st/mindflow/cell/kan/data/kan.ckpt differ
diff --git a/tests/st/mindflow/cell/kan/data/kan.npz b/tests/st/mindflow/cell/kan/data/kan.npz
new file mode 100644
index 0000000000000000000000000000000000000000..aa5f4454dc8d198362d8bc50d1f619cf23e48182
Binary files /dev/null and b/tests/st/mindflow/cell/kan/data/kan.npz differ
diff --git a/tests/st/mindflow/cell/kan/data/kan.pt b/tests/st/mindflow/cell/kan/data/kan.pt
new file mode 100644
index 0000000000000000000000000000000000000000..7cce7eda21c0d525429ec5eb302767066f5d5d46
Binary files /dev/null and b/tests/st/mindflow/cell/kan/data/kan.pt differ
diff --git a/tests/st/mindflow/cell/kan/data/subset.npz b/tests/st/mindflow/cell/kan/data/subset.npz
new file mode 100644
index 0000000000000000000000000000000000000000..99452536b0c63d3f8759cdc68b5d541ada79f7dd
Binary files /dev/null and b/tests/st/mindflow/cell/kan/data/subset.npz differ
diff --git a/tests/st/mindflow/cell/kan/data/update_output.npz b/tests/st/mindflow/cell/kan/data/update_output.npz
new file mode 100644
index 0000000000000000000000000000000000000000..50cb2ce13b2c9f30ef8e42a5c7bebfc5d5f68043
Binary files /dev/null and b/tests/st/mindflow/cell/kan/data/update_output.npz differ
diff --git a/tests/st/mindflow/cell/kan/test_kan.py b/tests/st/mindflow/cell/kan/test_kan.py
new file mode 100644
index 0000000000000000000000000000000000000000..07dea767df0e9c00f7f6460c2ae2023d40a17732
--- /dev/null
+++ b/tests/st/mindflow/cell/kan/test_kan.py
@@ -0,0 +1,141 @@
+"""test KAN"""
+import sys
+import os
+import pytest
+
+from mindspore import ops, Tensor, load_checkpoint, load_param_into_net, context
+import numpy as np
+from mindflow.cell import curve2coef, KANLayer
+
+# pylint: disable=C0413
+PROJECT_ROOT = os.path.abspath(os.path.join(
+    os.path.dirname(__file__), "../../../"))
+sys.path.append(PROJECT_ROOT)
+from common.cell import compare_output
+
+CKPT_PATH = './data/kan.ckpt'
+
+
+def load_kan_data():
+    data = np.load('./data/kan.npz')
+    y_gt = data['y']
+    x = data['x']
+    preacts_gt, postacts_gt, postspline_gt = data['preacts'], data['postacts'], data['postspline']
+    return x, y_gt, preacts_gt, postacts_gt, postspline_gt
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_curve2coef(mode):
+    """
+    Feature: curve2coef
+    Description: test curve2coef
+    Expectation: success
+    """
+    context.set_context(mode=mode)
+    func = curve2coef
+
+    def load_data():
+        data = np.load('./data/curve2coef.npz')
+        x, y, grid, k, out_gt = data['x'], data['y'], data['grid'], data['k'], data['out']
+        return x, y, grid, k, out_gt
+    x, y, grid, k, out_gt = load_data()
+    k = k.item()
+    out = func(Tensor(x), Tensor(y), Tensor(grid), k)
+    assert compare_output(out_gt, out.numpy())
+    print('test_curve2coef pass')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_kan_forward(mode):
+    """
+    Feature: KAN
+    Description: test KAN forward result
+    Expectation: success
+    """
+    context.set_context(mode=mode)
+    x, y_gt, preacts_gt, postacts_gt, postspline_gt = load_kan_data()
+    net = KANLayer()
+    params = load_checkpoint(CKPT_PATH)
+    load_param_into_net(net, params)
+    y, preacts, postacts, postspline = net(Tensor(x))
+    assert compare_output(y.numpy(), y_gt)
+    assert compare_output(preacts.numpy(), preacts_gt)
+    assert compare_output(postacts.numpy(), postacts_gt)
+    assert compare_output(postspline.numpy(), postspline_gt)
+    print('test_kan_forward pass')
+
+
+def load_grid_coef(npz_file):
+    data = np.load(npz_file)
+    return data['grid'], data['coef']
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_update_grid_from_samples(mode):
+    """
+    Feature: KAN update_grid_from_samples
+    Description: test KAN update_grid_from_samples
+    Expectation: success
+    """
+    context.set_context(mode=mode)
+    layer = KANLayer()
+    load_param_into_net(layer, load_checkpoint(CKPT_PATH))
+    x = ops.linspace(-3, 3, steps=100)[:, None]
+    layer.update_grid_from_samples(x)
+    grid, coef = load_grid_coef('./data/update_output.npz')
+    assert compare_output(layer.grid.data.numpy(), grid)
+    assert compare_output(layer.coef.data.numpy(), coef)
+    print('test_update_grid_from_samples pass')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_initialize_grid_from_parent(mode):
+    """
+    Feature: initialize_grid_from_parent
+    Description: test KAN initialize_grid_from_parent
+    Expectation: success
+    """
+    context.set_context(mode=mode)
+    layer = KANLayer()
+    load_param_into_net(layer, load_checkpoint(CKPT_PATH))
+    parent = KANLayer()
+    load_param_into_net(parent, load_checkpoint(CKPT_PATH))
+    x = Tensor(np.load('./data/init_x.npy'))
+    layer.initialize_grid_from_parent(parent, x)
+    # print('initialize_grid_from_parent', layer.grid.data, layer.coef.data)
+    grid, coef = load_grid_coef('./data/init_output.npz')
+    assert compare_output(layer.grid.data.numpy(), grid)
+    assert compare_output(layer.coef.data.numpy(), coef)
+    print('test_initialize_grid_from_parent pass')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_get_subset(mode):
+    """
+    Feature: KAN get_subset
+    Description: test KAN get_subset
+    Expectation: success
+    """
+    context.set_context(mode=mode)
+    layer = KANLayer()
+    load_param_into_net(layer, load_checkpoint(CKPT_PATH))
+    child = layer.get_subset([1, 2], [1])
+    grid, coef = load_grid_coef('./data/subset.npz')
+    assert compare_output(child.grid.data.numpy(), grid)
+    assert compare_output(child.coef.data.numpy(), coef)
+    print('test_get_subset pass')
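
[Editor's note, not part of the diff] test_curve2coef above compares curve2coef against reference outputs stored in curve2coef.npz. For readers without the data files, a self-contained round trip through the module-level helpers in kan.py should approximately recover the coefficients it started from, assuming the batched lstsq path used by curve2coef works as the PR's tests expect. A sketch under that assumption:

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindflow.cell import curve2coef
from mindflow.cell.kan import coef2curve, extend_grid  # module-level helpers, not re-exported by mindflow.cell

in_dim, out_dim, num, k, batch = 3, 2, 5, 3, 256
grid = Tensor(np.tile(np.linspace(-1, 1, num + 1), (in_dim, 1)), mstype.float32)
grid = extend_grid(grid, k_extend=k)                      # (in_dim, num + 1 + 2k)
coef = Tensor(np.random.randn(in_dim, out_dim, num + k).astype(np.float32))
x = Tensor(np.random.uniform(-1, 1, size=(batch, in_dim)).astype(np.float32))

y = coef2curve(x, grid, coef, k)                          # (batch, in_dim, out_dim)
coef_rec = curve2coef(x, y, grid, k)                      # (in_dim, out_dim, num + k)
print(np.abs(coef_rec.asnumpy() - coef.asnumpy()).max())  # expected to be close to zero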