# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train / evaluate Transolver on the elasticity benchmark (irregular mesh)."""
import os
import argparse

import matplotlib.pyplot as plt
import numpy as np
import mindspore as ms
from mindspore import ops
from mindspore import nn
from mindspore.dataset import GeneratorDataset
from tqdm import tqdm

from src.datasets.dataset import RandomAccessDataset
from src.utils.testloss import TestLoss
from src.models.model_dict import get_model
from src.utils.normalizer import UnitTransformer

parser = argparse.ArgumentParser('Training Transolver')

parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--weight_decay', type=float, default=1e-5)
parser.add_argument('--model', type=str, default='TransolverIrregular')
parser.add_argument('--n-hidden', type=int, default=128, help='hidden dim')
parser.add_argument('--n-layers', type=int, default=8, help='layers')
parser.add_argument('--n-heads', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument("--gpu", type=str, default='0', help="GPU index to use")
parser.add_argument('--max_grad_norm', type=float, default=1.0)
parser.add_argument('--downsample', type=int, default=5)
parser.add_argument('--mlp_ratio', type=int, default=1)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--ntrain', type=int, default=1000)
parser.add_argument('--unified_pos', type=int, default=0)
parser.add_argument('--ref', type=int, default=8)
parser.add_argument('--slice_num', type=int, default=64)
parser.add_argument('--eval', type=int, default=0)
parser.add_argument('--save_name', type=str, default='elas_Transolver')
parser.add_argument('--data_path', type=str, default='./data')
args = parser.parse_args()
save_name = args.save_name


def count_parameters(model):
    """Print and return the number of trainable parameters of `model`."""
    total_params = sum(p.size for p in model.trainable_params())
    print(f"Total Trainable Params: {total_params}")
    return total_params


def main():
    """End-to-end pipeline: load data, build the model, then train or evaluate.

    With --eval=1 a saved checkpoint is restored, the test relative error is
    reported and a few ground-truth / prediction / error scatter plots are
    written to ./results/<save_name>/. Otherwise the model is trained for
    --epochs epochs, evaluated each epoch, and checkpointed periodically.
    """
    ntrain = args.ntrain
    ntest = 200

    # Per-sample stress field sigma and (x, y) mesh point coordinates.
    path_sigma = args.data_path + '/elasticity/Meshes/Random_UnitCell_sigma_10.npy'
    path_xy = args.data_path + '/elasticity/Meshes/Random_UnitCell_XY_10.npy'

    input_s = np.load(path_sigma).astype(np.float32)
    input_s = np.transpose(input_s, (1, 0))       # -> (samples, points)
    input_xy = np.load(path_xy).astype(np.float32)
    input_xy = np.transpose(input_xy, (2, 0, 1))  # -> (samples, points, 2)

    train_s = input_s[:ntrain]
    test_s = input_s[-ntest:]
    train_xy = input_xy[:ntrain]
    test_xy = input_xy[-ntest:]

    print(input_s.shape, input_xy.shape)

    # Train targets are normalized; predictions are decoded before the loss,
    # so train and test errors are both measured in the original scale.
    y_normalizer = UnitTransformer(train_s)
    train_s = y_normalizer.encode(train_s)

    train_dataset = RandomAccessDataset(train_xy, train_xy, train_s)
    test_dataset = RandomAccessDataset(test_xy, test_xy, test_s)
    train_loader = GeneratorDataset(source=train_dataset, column_names=['x', 'fx', 'y'],
                                    shuffle=True).batch(args.batch_size)
    test_loader = GeneratorDataset(source=test_dataset, column_names=['x', 'fx', 'y'],
                                   shuffle=False).batch(args.batch_size)

    print("Dataloading is over.")

    model = get_model(args.model)(space_dim=2,
                                  n_layers=args.n_layers,
                                  n_hidden=args.n_hidden,
                                  dropout=args.dropout,
                                  n_head=args.n_heads,
                                  time_input=False,
                                  mlp_ratio=args.mlp_ratio,
                                  fun_dim=0,
                                  out_dim=1,
                                  slice_num=args.slice_num,
                                  ref=args.ref,
                                  unified_pos=args.unified_pos)
    step_per_epoch = len(train_loader)
    # Cosine schedule decaying from --lr to 0 over the full run.
    cosine_decay_lr = nn.cosine_decay_lr(min_lr=0., max_lr=args.lr,
                                         total_step=args.epochs * step_per_epoch,
                                         step_per_epoch=step_per_epoch, decay_epoch=args.epochs)
    optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=cosine_decay_lr,
                                   weight_decay=args.weight_decay)

    print(args)
    print(model)
    count_parameters(model)

    myloss = TestLoss(size_average=False)

    def forward_fn(x, y):
        """Forward pass: predict, decode both sides, return the relative loss."""
        out = model(x, None).squeeze(-1)
        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)
        loss = myloss(out, y)
        return loss, y

    grad_fn = ops.value_and_grad(forward_fn, None, model.trainable_params(), has_aux=True)

    def train_step(x, y):
        """One optimizer step with global-norm gradient clipping."""
        (loss, _), grads = grad_fn(x, y)
        grads = ops.clip_by_global_norm(grads, clip_norm=args.max_grad_norm)
        optimizer(grads)
        return loss

    if args.eval:
        ms.load_param_into_net(model, ms.load_checkpoint("./checkpoints/" + save_name + ".ckpt"))
        model.set_train(False)
        if not os.path.exists('./results/' + save_name + '/'):
            os.makedirs('./results/' + save_name + '/')
        rel_err = 0.0
        showcase = 2  # only the first (showcase - 1) samples are plotted
        cnt = 0

        for pos, fx, y in test_loader.create_tuple_iterator():
            cnt += 1
            out = model(pos, None).squeeze(-1)
            out = y_normalizer.decode(out)
            tl = myloss(out, y).asnumpy()
            rel_err += tl
            if cnt < showcase:
                print(cnt)
                # Ground truth.
                plt.axis('off')
                plt.scatter(x=fx[0, :, 0], y=fx[0, :, 1],
                            c=y[0, :], cmap='coolwarm')
                plt.colorbar()
                plt.clim(0, 1000)
                plt.savefig(
                    os.path.join('./results/' + save_name + '/',
                                 "gt_" + str(cnt) + ".pdf"), bbox_inches='tight', pad_inches=0)
                plt.close()

                # Prediction.
                plt.axis('off')
                plt.scatter(x=fx[0, :, 0], y=fx[0, :, 1],
                            c=out[0, :], cmap='coolwarm')
                plt.colorbar()
                plt.clim(0, 1000)
                plt.savefig(
                    os.path.join('./results/' + save_name + '/',
                                 "pred_" + str(cnt) + ".pdf"), bbox_inches='tight', pad_inches=0)
                plt.close()

                # Pointwise error.
                plt.axis('off')
                plt.scatter(x=fx[0, :, 0], y=fx[0, :, 1],
                            c=((y[0, :] - out[0, :])), cmap='coolwarm')
                plt.clim(-8, 8)
                plt.colorbar()
                plt.savefig(
                    os.path.join('./results/' + save_name + '/',
                                 "error_" + str(cnt) + ".pdf"), bbox_inches='tight', pad_inches=0)
                plt.close()

        rel_err /= ntest
        print(f"rel_err : {rel_err}")
    else:
        print('Training...')
        for ep in range(args.epochs):

            model.set_train()
            train_loss = 0

            b_i, total_steps = 0, step_per_epoch
            for pos, _, y in (pbar := tqdm(train_loader.create_tuple_iterator(), total=total_steps)):
                loss = train_step(pos, y)
                train_loss += loss.asnumpy()
                if (b_i + 1) % 5 == 0 or (b_i + 1) == total_steps:
                    pbar.set_description(f'Epoch {ep+1} Iter {b_i+1}/{total_steps} Loss {loss.asnumpy():.5f}')
                b_i += 1

            train_loss = train_loss / ntrain
            print(f"Epoch {ep + 1} Train loss : {train_loss:.5f}")

            model.set_train(False)
            rel_err = 0.0
            # Fix: iterate via create_tuple_iterator() like every other loop,
            # so the unpacked items are guaranteed to be (pos, fx, y) tensors.
            for pos, _, y in tqdm(test_loader.create_tuple_iterator()):
                out = model(pos, None).squeeze(-1)
                out = y_normalizer.decode(out)
                tl = myloss(out, y).asnumpy()
                rel_err += tl

            rel_err /= ntest
            print(f"rel_err : {rel_err}")

            # Checkpoint after the first epoch and then every 5 epochs.
            if (ep + 1) == 1 or (ep + 1) % 5 == 0:
                if not os.path.exists('./checkpoints'):
                    os.makedirs('./checkpoints')
                print('save model')
                ms.save_checkpoint(model, os.path.join('./checkpoints', save_name + '.ckpt'))

        # Final checkpoint after training completes.
        if not os.path.exists('./checkpoints'):
            os.makedirs('./checkpoints')
        print('save model')
        ms.save_checkpoint(model, os.path.join('./checkpoints', save_name + '.ckpt'))


if __name__ == "__main__":
    main()
"""dataset"""


class RandomAccessDataset:
    """A dataset class that provides random access to (x, fx, y) samples.

    Each argument is an array-like indexed along its first axis; sample `idx`
    is the tuple (x[idx], fx[idx], y[idx]).
    """

    def __init__(self, x, fx, y):
        self.x = x    # point coordinates
        self.fx = fx  # point-wise input function values
        self.y = y    # targets

    def __len__(self):
        """Number of samples (size of the leading axis)."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return the idx-th (x, fx, y) triple."""
        return self.x[idx], self.fx[idx], self.y[idx]


# --- src/models/embedding.py --------------------------------------------------
"""embedding"""


def timestep_embedding(timesteps, dim, max_period=10000):
    """Create sinusoidal timestep embeddings.

    Args:
        timesteps: 1-D tensor of N timestep indices, one per batch element.
        dim: width of the output embedding.
        max_period: controls the minimum frequency of the embeddings.

    Returns:
        A (N, dim) tensor of sinusoidal embeddings.
    """
    # Imports kept local so the lightweight dataset above stays importable
    # without a MindSpore installation.
    import math
    import mindspore as ms
    from mindspore import ops

    half = dim // 2
    # Fix: the original was un-ported torch code — `.to(device=...)`,
    # `ops.float32` and `dim=` kwargs do not exist in MindSpore.
    # Standard Transformer frequency schedule: exp(-log(max_period) * k / half).
    freqs = ops.exp(
        ops.arange(start=0, end=half, dtype=ms.float32) * (-math.log(max_period) / half)
    )
    args = timesteps[:, None].astype(ms.float32) * freqs[None]
    embedding = ops.cat([ops.cos(args), ops.sin(args)], axis=-1)
    if dim % 2:
        # Odd dim: pad one zero column so the output width is exactly `dim`.
        embedding = ops.cat([embedding, ops.zeros_like(embedding[:, :1])], axis=-1)
    return embedding
# ==============================================================================
"""model dict"""
from src.models.transolver_irregular_mesh import TransolverIrregular


def get_model(model_name):
    """Look up a model class by registry name.

    Args:
        model_name: registry key, e.g. 'TransolverIrregular'.

    Returns:
        The model class (not an instance); call it with constructor kwargs.

    Raises:
        ValueError: if `model_name` is not registered.
    """
    model_dict = {
        'TransolverIrregular': TransolverIrregular,
    }
    if model_name not in model_dict:
        # Fail with an actionable message instead of a bare KeyError.
        raise ValueError(f"Unknown model '{model_name}'. Available: {sorted(model_dict)}")
    return model_dict[model_name]
# ==============================================================================
"physics attention"
import mindspore as ms
from mindspore import nn
from mindspore import ops
from mindspore import mint
from mindspore.common.initializer import Orthogonal

class PhysicsAttentionIrregularMesh(nn.Cell):
    """Physics-Attention for irregular meshes in 1D, 2D or 3D space.

    Each of the `num` input points is softly assigned to `slice_num` slice
    tokens, multi-head self-attention runs among the slice tokens only
    (cost O(slice_num^2) instead of O(num^2)), and the attended tokens are
    scattered back to the points with the same assignment weights.
    """
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64):
        super().__init__()
        inner_dim = dim_head * heads
        self.dim_head = dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5  # 1/sqrt(d) scaled dot-product factor
        self.softmax = nn.Softmax(axis=-1)
        self.dropout = nn.Dropout(p=dropout)
        # Learnable per-head temperature for the slice-assignment softmax.
        self.temperature = ms.Parameter(ops.ones([1, heads, 1, 1]) * 0.5)

        self.in_project_x = nn.Linear(dim, inner_dim)   # branch producing slice weights
        self.in_project_fx = nn.Linear(dim, inner_dim)  # branch producing slice values
        # Orthogonal init keeps the initial slice assignments well spread.
        self.in_project_slice = nn.Linear(dim_head, slice_num, weight_init=Orthogonal())
        # NOTE(review): nn.Linear / bias= / p= follow the torch-compatible API
        # surface — confirm against the targeted MindSpore version (the classic
        # API spells these nn.Dense / has_bias).
        self.to_q = nn.Linear(dim_head, dim_head, bias=False)
        self.to_k = nn.Linear(dim_head, dim_head, bias=False)
        self.to_v = nn.Linear(dim_head, dim_head, bias=False)
        self.to_out = nn.SequentialCell(
            nn.Linear(inner_dim, dim),
            nn.Dropout(p=dropout)
        )

    def construct(self, x):
        """Slice -> attend -> de-slice.  x: (b_size, num, dim); same shape out."""
        # b_size num C
        b_size, num, _ = x.shape

        ### (1) Slice: soft-assign the num points to slice_num tokens per head.
        fx_mid = self.in_project_fx(x).reshape(b_size, num, self.heads, self.dim_head) \
            .permute(0, 2, 1, 3).contiguous()  # b_size H num C
        x_mid = self.in_project_x(x).reshape(b_size, num, self.heads, self.dim_head) \
            .permute(0, 2, 1, 3).contiguous()  # b_size H num C
        slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature)  # b_size H num G
        slice_norm = slice_weights.sum(2)  # b_size H G -- total mass assigned to each slice
        slice_token = mint.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
        # Normalize each slice token by its assigned mass (+1e-5 avoids /0 for
        # slices that received ~no weight); broadcast the norm over dim_head.
        slice_norm = ops.repeat_interleave((slice_norm + 1e-5)[:, :, :, None],
                                           repeats=self.dim_head, axis=-1)
        slice_token = slice_token / slice_norm

        ### (2) Attention among slice tokens (standard scaled dot-product).
        q_slice_token = self.to_q(slice_token)
        k_slice_token = self.to_k(slice_token)
        v_slice_token = self.to_v(slice_token)
        dots = ops.matmul(q_slice_token, k_slice_token.transpose((0, 1, -1, -2))) * self.scale
        attn = self.softmax(dots)
        attn = self.dropout(attn)
        out_slice_token = ops.matmul(attn, v_slice_token)  # b_size H G D

        ### (3) Deslice: scatter tokens back to points with the same weights.
        out_x = mint.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
        out_x = ops.permute(out_x, (0, 2, 1, 3))
        # Merge the head and channel axes back to inner_dim before projection.
        out_x = out_x.reshape(out_x.shape[0], out_x.shape[1], -1)
        return self.to_out(out_x)
# ==============================================================================
"""transolver irregular mesh"""
import mindspore as ms
from mindspore import nn
from mindspore import ops
from mindspore.common.initializer import initializer, TruncatedNormal
import numpy as np

from src.models.physics_attention import PhysicsAttentionIrregularMesh
from src.models.embedding import timestep_embedding


# Activation registry. Every value must be a zero-argument factory because MLP
# instantiates with `act()`.  (Fix: the original stored an nn.LeakyReLU
# *instance* under 'leaky_relu', so `act()` failed for that key.)
ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU,
              'leaky_relu': (lambda: nn.LeakyReLU(0.1)),
              'ELU': nn.ELU, 'silu': nn.SiLU}


class MLP(nn.Cell):
    """Linear + act, then `n_layers` hidden layers (residual if `res`), then a final linear."""

    def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True):
        super().__init__()

        if act in ACTIVATION:
            act = ACTIVATION[act]
        else:
            raise NotImplementedError(f"Unsupported activation: {act}")
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_output = n_output
        self.n_layers = n_layers
        self.res = res
        self.linear_pre = nn.SequentialCell(nn.Linear(n_input, n_hidden), act())
        self.linear_post = nn.Linear(n_hidden, n_output)
        self.linears = nn.CellList(
            [nn.SequentialCell(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)])

    def construct(self, x):
        """Apply pre-projection, hidden layers, post-projection."""
        x = self.linear_pre(x)
        for i in range(self.n_layers):
            if self.res:
                x = self.linears[i](x) + x
            else:
                x = self.linears[i](x)
        x = self.linear_post(x)
        return x


class TransolverBlock(nn.Cell):
    """Transformer encoder block: physics-attention + MLP, pre-LayerNorm, residual."""

    def __init__(
            self,
            num_heads: int,
            hidden_dim: int,
            dropout: float,
            act='gelu',
            mlp_ratio=4,
            last_layer=False,
            out_dim=1,
            slice_num=32,
    ):
        super().__init__()
        self.last_layer = last_layer
        self.ln_1 = nn.LayerNorm([hidden_dim])
        self.attn = PhysicsAttentionIrregularMesh(hidden_dim, heads=num_heads,
                                                  dim_head=hidden_dim // num_heads,
                                                  dropout=dropout, slice_num=slice_num)
        self.ln_2 = nn.LayerNorm([hidden_dim])
        self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
        if self.last_layer:
            # The final block additionally projects hidden features to out_dim.
            self.ln_3 = nn.LayerNorm([hidden_dim])
            self.mlp2 = nn.Linear(hidden_dim, out_dim)

    def construct(self, fx):
        """Residual attention + residual MLP; output head on the last block."""
        fx = self.attn(self.ln_1(fx)) + fx
        fx = self.mlp(self.ln_2(fx)) + fx
        if self.last_layer:
            fx = self.mlp2(self.ln_3(fx))
        return fx


class TransolverIrregular(nn.Cell):
    """Transolver for irregular meshes: point embedding + stacked TransolverBlocks.

    Args:
        space_dim: dimension of the point coordinates.
        n_layers / n_hidden / n_head / dropout / mlp_ratio / slice_num: block config.
        time_input: enable an additive per-sample timestep embedding.
        fun_dim: width of the optional point-wise input function `fx` (0 = none).
        out_dim: output channels per point.
        ref: side length of the reference grid used when unified_pos is set.
        unified_pos: replace raw coordinates with distances to a ref*ref grid.
    """

    def __init__(self,
                 space_dim=1,
                 n_layers=5,
                 n_hidden=256,
                 dropout=0.0,
                 n_head=8,
                 time_input=False,
                 act='gelu',
                 mlp_ratio=1,
                 fun_dim=1,
                 out_dim=1,
                 slice_num=32,
                 ref=8,
                 unified_pos=False
                 ):
        super().__init__()
        self.ref = ref
        self.unified_pos = unified_pos
        self.time_input = time_input
        self.n_hidden = n_hidden
        self.space_dim = space_dim
        if self.unified_pos:
            # Input per point is the ref*ref distance field instead of coords.
            self.preprocess = MLP(fun_dim + self.ref * self.ref, n_hidden * 2, n_hidden,
                                  n_layers=0, res=False, act=act)
        else:
            self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden,
                                  n_layers=0, res=False, act=act)
        if time_input:
            # Fix: MindSpore's container is nn.SequentialCell (nn.Sequential is
            # the torch name and does not exist here).
            self.time_fc = nn.SequentialCell(nn.Linear(n_hidden, n_hidden), nn.SiLU(),
                                             nn.Linear(n_hidden, n_hidden))

        self.blocks = nn.CellList([TransolverBlock(num_heads=n_head, hidden_dim=n_hidden,
                                                   dropout=dropout,
                                                   act=act,
                                                   mlp_ratio=mlp_ratio,
                                                   out_dim=out_dim,
                                                   slice_num=slice_num,
                                                   last_layer=i == (n_layers - 1))
                                   for i in range(n_layers)])
        self.initialize_weights()
        # Learnable constant added to every token embedding.
        self.placeholder = ms.Parameter((1 / n_hidden) * ops.rand(n_hidden, dtype=ms.float32))

    def initialize_weights(self):
        """Apply truncated-normal initialization to all Linear layers."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init one cell: trunc-normal(0.02) weight and zero bias for Linear layers."""
        if isinstance(m, nn.Linear):
            m.weight.set_data(initializer(TruncatedNormal(sigma=0.02), m.weight.shape, m.weight.dtype))
        if isinstance(m, nn.Linear) and m.bias is not None:
            m.bias.set_data(initializer(0, m.bias.shape, m.bias.dtype))

    def get_grid(self, x, batchsize=1):
        """Distances from each mesh point to every node of a uniform ref*ref grid.

        Args:
            x: (B, N, 2) point coordinates in [0, 1]^2.

        Returns:
            (B, N, ref*ref) tensor of Euclidean distances.
        """
        # Fix: the original used torch-style Tensor.repeat([...]); MindSpore's
        # Tensor.repeat is numpy-like (repeat_interleave), so use ops.tile.
        gridx = ms.Tensor(np.linspace(0, 1, self.ref), dtype=ms.float32)
        gridx = ops.tile(gridx.reshape(1, self.ref, 1, 1), (batchsize, 1, self.ref, 1))
        gridy = ms.Tensor(np.linspace(0, 1, self.ref), dtype=ms.float32)
        gridy = ops.tile(gridy.reshape(1, 1, self.ref, 1), (batchsize, self.ref, 1, 1))
        # Fix: ops.cat takes `axis`, not torch's `dim`.
        grid_ref = ops.cat((gridx, gridy), axis=-1).reshape(batchsize, self.ref * self.ref, 2)

        pos = ops.sqrt(ops.sum((x[:, :, None, :] - grid_ref[:, None, :, :]) ** 2, dim=-1))
        pos = pos.reshape(batchsize, x.shape[1], self.ref * self.ref)
        return pos

    def construct(self, x, fx, time=None):
        """Embed the inputs and run all Transolver blocks.

        Args:
            x: (B, N, space_dim) point coordinates.
            fx: optional (B, N, fun_dim) point-wise input function, or None.
            time: optional per-sample timesteps (requires time_input=True);
                assumed shape (B,) -- TODO confirm against callers.
        """
        if self.unified_pos:
            x = self.get_grid(x, x.shape[0])
        if fx is not None:
            fx = ops.cat((x, fx), -1)
            fx = self.preprocess(fx)
        else:
            fx = self.preprocess(x)
        fx = fx + self.placeholder[None, None, :]

        if time is not None:
            # Broadcast one embedding per sample to every mesh point.
            # (Fix: replaced torch-style 2-D Tensor.repeat(1, N, 1) with an
            # explicit expand + tile.)
            time_emb = timestep_embedding(time, self.n_hidden)[:, None, :]
            time_emb = ops.tile(time_emb, (1, x.shape[1], 1))
            time_emb = self.time_fc(time_emb)
            fx = fx + time_emb

        for block in self.blocks:
            fx = block(fx)

        return fx
# ==============================================================================
"""normalizer"""
import numpy as np
from mindspore import Tensor

class UnitTransformer():
    """Per-feature unit-Gaussian normalizer.

    Statistics are computed once over the (sample, point) axes of the training
    array; an epsilon of 1e-8 is folded into `std` to avoid division by zero.

    NOTE(review): `encode` consumes/returns NumPy arrays while `decode` and
    `transform` operate on MindSpore tensors — callers must respect that split
    (exp_elas.py encodes the raw numpy targets once, then decodes model
    outputs); confirm before reusing elsewhere.
    """
    def __init__(self, x):
        # keepdims so the (1, 1, features) stats broadcast over samples/points.
        self.mean = Tensor(np.mean(x, axis=(0, 1), keepdims=True))
        self.std = Tensor(np.std(x, axis=(0, 1), keepdims=True) + 1e-8)

    def encode(self, x):
        """Normalize a NumPy array to zero mean / unit std (NumPy in, NumPy out)."""
        x = (x - self.mean.asnumpy()) / (self.std.asnumpy())
        return x

    def decode(self, x):
        """Map a normalized tensor back to the original scale (exact inverse of encode)."""
        return x * self.std + self.mean

    def transform(self, x, inverse=True, component='all'):
        """Normalize (inverse=False) or denormalize (inverse=True) `x`.

        `component` selects feature columns; 'all' / 'all-reduce' keep all.

        NOTE(review): the inverse path multiplies by `std - 1e-8`, i.e. the raw
        std without the epsilon added in __init__ — this mirrors the upstream
        Transolver implementation but is *not* the exact inverse of `encode`;
        confirm before relying on encode/transform round-trips.
        """
        if component in ('all', 'all-reduce'):
            if inverse:
                orig_shape = x.shape
                out = (x * (self.std - 1e-8) + self.mean).view(orig_shape)
            else:
                out = (x - self.mean) / self.std
        else:
            if inverse:
                orig_shape = x.shape
                out = (x * (self.std[:, component] - 1e-8) + \
                    self.mean[:, component]).view(orig_shape)
            else:
                out = (x - self.mean[:, component]) / self.std[:, component]
        return out
# ==============================================================================
"""test loss"""
from mindspore import ops


class TestLoss:
    """Absolute / relative Lp losses over flattened per-sample predictions.

    Calling the object computes the relative Lp error (`rel`).

    Args:
        d: spatial dimension (used by `abs` for the grid measure h**(d/p)).
        p: order of the norm.
        size_average: when reducing, mean over the batch (True) vs sum (False).
        reduction: whether to reduce over the batch at all.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super().__init__()
        if d <= 0 or p <= 0:
            # Raise instead of assert so validation survives `python -O`.
            raise ValueError("d and p must be positive")
        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        """Absolute Lp norm of (x - y), scaled by the uniform grid spacing.

        Fix: the original used torch-style `x.size()[0]` / `x.view`; MindSpore's
        `Tensor.size` is an int attribute, not a method, so use `shape` /
        `reshape` (consistent with `rel`).
        """
        num_examples = x.shape[0]
        h = 1.0 / (x.shape[1] - 1.0)  # uniform grid spacing

        all_norms = (h ** (self.d / self.p)) * ops.norm(
            x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1)

        if self.reduction:
            if self.size_average:
                all_norms = ops.mean(all_norms)
            else:
                all_norms = ops.sum(all_norms)

        return all_norms

    def rel(self, x, y):
        """Relative Lp error ||x - y||_p / ||y||_p per sample, optionally reduced."""
        num_examples = x.shape[0]

        diff_norms = ops.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1)
        y_norms = ops.norm(y.reshape(num_examples, -1), self.p, 1)
        if self.reduction:
            if self.size_average:
                out = ops.mean(diff_norms / y_norms)
            else:
                out = ops.sum(diff_norms / y_norms)
        else:
            out = diff_norms / y_norms
        return out

    def __call__(self, x, y):
        """Compute the relative Lp error."""
        return self.rel(x, y)