diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/README.md b/MindFlow/applications/data_mechanism_fusion/p2c2net/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9e63f301dc4a34d4e20827a7ef7ce449e875541d --- /dev/null +++ b/MindFlow/applications/data_mechanism_fusion/p2c2net/README.md @@ -0,0 +1,68 @@ +ENGLISH | [简体中文](README_CN.md) + +# Solving the 2D Burgers' Equation Using P2C2Net + +## Overview + +**P2C2Net (PDE-Preserved Coarse Correction Network)** is a neural network architecture designed to efficiently solve spatiotemporal partial differential equations (PDEs) on coarse mesh grids with limited training data. The original paper is *"P2C2Net: PDE-Preserved Coarse Correction Network for Efficient Prediction of Spatiotemporal Dynamics"*. +![model architecture](images/model_architecture.png) +The model consists of two synergistic modules: (1) a trainable PDE block that learns to update the coarse solution (i.e., the system state), based on a high-order numerical scheme with boundary condition encoding, and (2) a neural network block that consistently corrects the solution on the fly. In particular, the model adopts a learnable symmetric Conv filter, with weights shared over the entire model, to accurately estimate the spatial derivatives of the PDE based on the neural-corrected system state. + +The Burgers’ equation is a nonlinear PDE that models the propagation and reflection of shock waves. It is widely used in fluid mechanics, nonlinear acoustics, gas dynamics, and other fields. In this project, we focus on solving the **2D Burgers’ equation** efficiently using P2C2Net. + +--- + +## Quick Start + +### 1. Data Generation +First, generate training and testing data by running: + +```shell +cd src +python dataGen.py +``` + +### 2. 
Training +Run the following command to train P2C2Net on the generated data: + +```shell +python p2c2net/train_burgers.py --experiment p2c2net +``` + +where + +`--experiment` is the experiment directory. It should include experiment specifications under `config/`; + +`--mode` is the running mode. 'GRAPH' indicates static graph mode. 'PYNATIVE' indicates dynamic graph mode. Default 'GRAPH'; + +`--device_target` is the type of computing platform to use, either 'Ascend' or 'GPU', default 'Ascend'; + +`--device_id` is the ID of the compute card to use, set according to your environment, default 0; + +`--continue` specifies whether to resume training from a saved checkpoint, default False; + +`--config_filename` is the name of the configuration file (under the `config/` directory) that defines experiment settings such as model parameters and the training schedule, default 'burgers.json'; + +`--train_stage` specifies whether to enable the training mode, default True; + +`--test_stage` specifies whether to enable the testing mode, default True; + +### 3. Result +After training, experiment outputs (checkpoints and evaluation results) are saved in the `result` directory under the `--experiment` directory you provided. Use the saved checkpoints to reproduce evaluations or continue training. + +#### Running log +![runtime log](images/runtime.png) + +#### Inference result +![inference result](images/inference.png) + +## Requirements +1. Python>=3.9 +2. MindSpore>=2.5 +3. 
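MindFlow>=0.3.0 + +## Contributor + +gitee id:[liuguangyuu](https://gitee.com/liuguangyuu) + +email: liuguangyuu@outllook.com \ No newline at end of file diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/README_CN.md b/MindFlow/applications/data_mechanism_fusion/p2c2net/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..c509b351c8978a09beb6e5013dfe6e62ee83798f --- /dev/null +++ b/MindFlow/applications/data_mechanism_fusion/p2c2net/README_CN.md @@ -0,0 +1,68 @@ +[ENGLISH](README.md) | 简体中文 + +# Solving the 2D Burgers' Equation Using P2C2Net + +## Overview + +**P2C2Net (PDE-Preserved Coarse Correction Network)** is a neural network architecture designed to efficiently solve spatiotemporal partial differential equations (PDEs) on coarse grids with limited training data. The original paper is *"P2C2Net: PDE-Preserved Coarse Correction Network for Efficient Prediction of Spatiotemporal Dynamics"*. +![model architecture](images/model_architecture.png) +As shown in the figure above, the model consists of two synergistic modules: (1) a trainable PDE block that learns to update the coarse-grid solution, based on a high-order numerical scheme combined with boundary condition encoding; (2) a neural network correction block that consistently corrects the solution on the fly during prediction. In particular, P2C2Net adopts a learnable symmetric convolution filter, with weights shared over the entire model, to accurately estimate the spatial derivatives of the PDE based on the neural-corrected system state. + +The Burgers' equation is a nonlinear partial differential equation that describes the propagation and reflection of shock waves. It is widely used in fluid mechanics, nonlinear acoustics, gas dynamics, and other fields. In this project, we focus on solving the **2D Burgers' equation** efficiently using P2C2Net. + +--- +## Quick Start + +### 1. Data Generation +First, run the following commands to generate the training and testing data: + +```shell +cd src +python dataGen.py +``` + +### 2. Training +Run the following command to train P2C2Net on the generated data: + +```shell +python p2c2net/train_burgers.py --experiment p2c2net +``` + +where + +`--experiment` is the experiment directory. It should include the experiment configuration files under `config/`; + +`--mode` is the running mode. 'GRAPH' indicates static graph mode. 'PYNATIVE' indicates dynamic graph mode. See the [MindSpore website](https://www.mindspore.cn/docs/zh-CN/r2.0/design/dynamic_graph_and_static_graph.html?highlight=pynative) for details. Default 'GRAPH'; + +`--device_target` is the type of computing platform to use, either 'Ascend' or 'GPU', default 'Ascend'; + +`--device_id` is the ID of the compute card to use, default 0; + +`--continue` specifies whether to resume training from an existing checkpoint, default False; + +`--config_filename` is the name of the configuration file (under the `config/` directory) that defines experiment settings such as model parameters and training parameters, default 'burgers.json'; + +`--train_stage` specifies whether to enable the training mode, default True; + +`--test_stage` specifies whether to enable the testing mode, default True; + +### 3. 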
Result +After training, experiment outputs (checkpoints and evaluation results) are saved in the `result` folder under the `--experiment` directory you provided. Use the saved checkpoints to reproduce evaluations or continue training. + +#### Running log +![runtime log](images/runtime.png) + +#### Inference result +![inference result](images/inference.png) + + +## Requirements +1. Python>=3.9 +2. MindSpore>=2.5 +3. MindFlow>=0.3.0 + +## Contributor + +gitee id:[liuguangyuu](https://gitee.com/liuguangyuu) + +email: liuguangyuu@outllook.com \ No newline at end of file diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/config/burgers.json b/MindFlow/applications/data_mechanism_fusion/p2c2net/config/burgers.json new file mode 100644 index 0000000000000000000000000000000000000000..704faabead79ee291d8ace4b388873839d53267b --- /dev/null +++ b/MindFlow/applications/data_mechanism_fusion/p2c2net/config/burgers.json @@ -0,0 +1,33 @@ +{ + "milestone_num": 10, + "epochs": 500, + "learning_rate": 0.005, + "weight_decay": 0.5, + "save_every": 25, + "gamma": 0.96, + "num_data": 15, + "num_train": 5, + "num_test": 10, + "train_window": 20, + "pretrain_window": 0, + "pretrain_iters": 0, + "timesteps": 400, + "size": 104, + "delta_t": 0.001, + "model": { + "in_channels": 1, + "out_channels": 1, + "viscosity": 0.005, + "modes": 12, + "hidden_channels": 12, + "projection_channels": 50, + "depths": 2 + }, + "batch_size": 100, + "down": 4, + "time_down": 1, + "nolap": true, + "drop": true, + "inferstep": 1400, + "normalization": false +} \ No newline at end of file diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/images/inference.png b/MindFlow/applications/data_mechanism_fusion/p2c2net/images/inference.png new file mode 100644 index 0000000000000000000000000000000000000000..320cec12a42e8756c2c51bb43de27f9905d709e4 Binary files /dev/null and b/MindFlow/applications/data_mechanism_fusion/p2c2net/images/inference.png differ diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/images/model_architecture.png b/MindFlow/applications/data_mechanism_fusion/p2c2net/images/model_architecture.png new file mode 100644 index 
0000000000000000000000000000000000000000..f538671f0a84de9e0e0d855f5ae72694ef104cac Binary files /dev/null and b/MindFlow/applications/data_mechanism_fusion/p2c2net/images/model_architecture.png differ diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/images/runtime.png b/MindFlow/applications/data_mechanism_fusion/p2c2net/images/runtime.png new file mode 100644 index 0000000000000000000000000000000000000000..d30a49d87eb09f2ad236db391e784fb0532761b4 Binary files /dev/null and b/MindFlow/applications/data_mechanism_fusion/p2c2net/images/runtime.png differ diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/src/data.py b/MindFlow/applications/data_mechanism_fusion/p2c2net/src/data.py new file mode 100644 index 0000000000000000000000000000000000000000..f426173bc457b4b7d1297c8d5133da0c750176fe --- /dev/null +++ b/MindFlow/applications/data_mechanism_fusion/p2c2net/src/data.py @@ -0,0 +1,251 @@ +import matplotlib.pyplot as plt +from mindspore.dataset import GeneratorDataset +import random +from mindspore import Tensor +import numpy as np +import mindspore.ops as ops +import os +import scipy.io +import pandas as pd + +def generate_dataset(data, train_win, icnum, nolap=True): + + gap = train_win if nolap else 8 + start = [i for i in range(0, len(data[0]) - train_win + 1, gap)] + + random.shuffle(start) + train_set = [] + for i in start: + for j in range(icnum): + train_set.append(data[j][i:i + train_win]) + all_shuffle = [] + rc = [i for i in range(len(train_set))] + random.shuffle(rc) + for i in rc: + all_shuffle.append(train_set[i]) + all_shuffle = np.array(all_shuffle) + return all_shuffle + + +def plot_loss(train_loss, save_dir): + iter = [i for i in range(1, len(train_loss) + 1)] + plt.plot(iter, train_loss, color="red", label="train_loss") + plt.title("Loss_iters_burgers", fontsize=24) + plt.xlabel("iters", fontsize=14) + plt.ylabel("loss", fontsize=14) + plt.tick_params(axis="both", labelsize=14) + plt.legend(fontsize=16) + 
plt.savefig(save_dir + "/train_burgers_loss.png", dpi=600) + plt.close() + print("plot burgers loss over") + + + +class MyDataset: + def __init__(self, data_features, eps=1.0e-5): + self.len = len(data_features) + self.features = data_features + self.eps = eps + + + def __getitem__(self, index): + feature = self.features[index] + x = feature[0:1, ...] + seq = feature[1:, ...] + return x, seq + + def __len__(self): + return self.len + +class UnitGaussianNormalizer(): + def __init__(self, x, normalization_dim = [], eps=1.0e-5): + super(UnitGaussianNormalizer, self).__init__() + self.mean = np.mean(x, axis=normalization_dim, keepdims=True) + self.std = np.std(x, axis=normalization_dim, keepdims=True) + print("mean, std", self.mean, self.std) + self.eps = eps + + def encode(self, x): + x = (x - self.mean) / (self.std + self.eps) + return x + + + def decode(self, x): + std = self.std + self.eps + mean = self.mean + x = x * std + mean + return x + + +def get_data( + exp_dir, + num=5, + down=4, +): + all_data = [] + for i in range(1, num + 1): + data_dir = os.path.join(exp_dir, 'data', 'Burgers_2101x2x104x104_[RK4,R=200,dt=0_001,#seed' + str(i) + '].mat') + cur_UV = scipy.io.loadmat(data_dir)['uv'] + cur_uv = np.ascontiguousarray(cur_UV[100:2101, ...], dtype=np.float32) + cur_truth = cur_uv[:, :, ::down, ::down] + all_data.append(cur_truth) + all_data = np.array(all_data) # [5, 2000, 2, 26, 26] + + return all_data + + +def infer(model, init, infersteps, compute_dtype): + model.steps = infersteps + init = Tensor(init, dtype=compute_dtype) + outputs_uv = model(init) + return outputs_uv + +def evl_error(test_data, model, inferstep, delta_t, resolution, err_save_dir, compute_dtype, normalizer=None, ploterrorprop=True, plotsnapshots=True): + + pred = [] + ground_truth = [] + for i, data in enumerate(test_data): + init = data[0:1] + truth = data[1:inferstep] + output = infer(model, init, inferstep, compute_dtype) + output = output.asnumpy() + if normalizer: + output = 
normalizer.decode(output) + truth = normalizer.decode(truth) + output = output.squeeze() + truth = truth.squeeze() + + if plotsnapshots: + print("snapshots plotting") + # plot_snapshots expects (truth, prediction), in that order + plot_snapshots(truth[-1, :, :, :], output[-1, :, :, :], resolution, os.path.join(err_save_dir, f"test_set_{i+1}_burgers-snap")) + + + if ploterrorprop: + print("error propagation plotting") + uv_truth = truth.transpose(1, 0, 2, 3) + uv_net = output.transpose(1, 0, 2, 3) + plot_err_prop(uv_truth, uv_net, delta_t, os.path.join(err_save_dir, f"test_set_{i+1}_burgers-err-propa.png")) + + + pred.append(np.expand_dims(output.transpose(1, 0, 2, 3), axis=0)) + ground_truth.append(np.expand_dims(truth.transpose(1, 0, 2, 3), axis=0)) + + pred = np.concatenate(pred, axis=0) + ground_truth = np.concatenate(ground_truth, axis=0) + rmse = np.sqrt(np.mean(np.sum((ground_truth - pred)**2, axis=1))) + mae = np.mean(np.abs(pred - ground_truth)) + truth_norms = np.linalg.norm(ground_truth, axis=1) + mnad = np.mean(np.linalg.norm(ground_truth - pred, axis=1) / (np.max(truth_norms) - np.min(truth_norms))) + hct = 0 + for i in range(ground_truth.shape[0]): + pcc = np.corrcoef(ground_truth[i].flatten(), pred[i].flatten())[0, 1] + if not np.isnan(pcc) and pcc > 0.8: + hct += delta_t + + print("rmse, mae, mnad, hct", rmse, mae, mnad, hct) + error_data = { + 'Metric': ['RMSE', 'MAE', 'MNAD', 'HCT'], + 'Value': [rmse, mae, mnad, hct] + } + df = pd.DataFrame(error_data) + df.to_csv(os.path.join(err_save_dir, 'evaluation_metrics.csv'), index=False) + + return rmse, mae, mnad, hct + +def plot_snapshots(UV_TRUTH, UV_Net, resolution, fig_save_dir): + z1 = UV_TRUTH[0, :, :] + z2 = UV_Net[0, :, :] + z3 = UV_TRUTH[1, :, :] + z4 = UV_Net[1, :, :] + xx = np.linspace(0, 1, resolution) + yy = np.linspace(0, 1, resolution) + x, y = np.meshgrid(xx, yy) + + h1 = np.max([np.max(z1), np.max(z2)]) + l1 = np.min([np.min(z1), np.min(z2)]) + level1 = np.linspace(l1, h1, 10, endpoint=True) + h2 = np.max([np.max(z3), np.max(z4)]) + l2 = 
np.min([np.min(z3), np.min(z4)]) + level2 = np.linspace(l2, h2, 10, endpoint=True) + + fig, axs = plt.subplots(2, 2, figsize=(18, 12)) + ax1, ax2, ax3, ax4 = axs.flatten() + size = 24 + fig.subplots_adjust(left=0.1, right=0.9, wspace=0.25, hspace=0.4) + + cf1 = ax1.contourf(x, y, z1, level1, cmap='coolwarm') + ax1.set_title(r'$Ref.$', fontsize=size+8) + ax1.text(-0.2, 0.5, r"$u$", + ha='center', + va='center', + fontsize=size+12, + transform=ax1.transAxes) + ax1.set_xticks([]) + ax1.set_yticks([]) + ax1.set_aspect('equal') + + + + cf2 = ax2.contourf(x, y, z2, level1, cmap='coolwarm') + ax2.set_title(r'$P^2C^2Net$', fontsize=size+8) + ax2.set_xticks([]) + ax2.set_yticks([]) + ax2.set_aspect('equal') + + cf3 = ax3.contourf(x, y, z3, level2, cmap='coolwarm') + ax3.set_title(r'$Ref.$', fontsize=size+8) + ax3.text(-0.25, 0.5, r"$v$", + ha='center', + va='center', + fontsize=size+12, + transform=ax3.transAxes) + ax3.set_xticks([]) + ax3.set_yticks([]) + ax3.set_aspect('equal') + + + cf4 = ax4.contourf(x, y, z4, level2, cmap='coolwarm') + ax4.set_title(r'$P^2C^2Net$', fontsize=size+8) + ax4.set_xticks([]) + ax4.set_yticks([]) + ax4.set_aspect('equal') + + plt.savefig(fig_save_dir, dpi=600) + plt.clf() + plt.close() + +def plot_err_prop(UV_TRUTH, UV_Net, dt, fig_save_dir): + # [2,steps,26,26] + eps = 1e-4 + + MSE = np.mean((UV_TRUTH - UV_Net) ** 2, axis=(0, 2, 3)) + # accumulated error: RMSE over all steps up to and including step i + accum = np.array([[i * dt, eps + np.sqrt(MSE[:i + 1].mean())] for i in range(0, UV_Net.shape[1], 1)]) + fig = plt.figure(figsize=(6, 4)) + ax = fig.add_axes([0.13, 0.12, 0.8, 0.80]) + ax.plot(accum[:, 0], accum[:, 1], alpha=0.8, linewidth=2, color='black', label=r'$P^2C^2Net$') + + ax.set_xlim([0, 1.4]) + ax.set_ylim([eps, 1e0]) + ax.set_xticks([0.0, 0.7, 1.4]) + ax.set_yticks([eps, 1e-1]) + ax.set_yscale('log') + ax.set_ylabel(r'a-RMSE', fontsize=14) + ax.set_xlabel('t(s)', fontsize=14, labelpad=-0.0) + ax.tick_params(labelsize=14, direction='in') + ax.set_title('Error propagation', fontsize=16) + 
plt.legend() + + plt.savefig(fig_save_dir, dpi=600) + plt.close() + + +def cal_reffe(output, truth): + nume = np.linalg.norm(output - truth) + deno = np.linalg.norm(truth) + epsino = nume / deno + return epsino + +def ensure_directories(*dirs): + for dir in dirs: + os.makedirs(dir, exist_ok=True) \ No newline at end of file diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/src/dataGen.py b/MindFlow/applications/data_mechanism_fusion/p2c2net/src/dataGen.py new file mode 100644 index 0000000000000000000000000000000000000000..4e39ca82f7feec7503a6b4ad3b6a814713eb736a --- /dev/null +++ b/MindFlow/applications/data_mechanism_fusion/p2c2net/src/dataGen.py @@ -0,0 +1,315 @@ +'''FD solver for 2D Burgers equation''' +# spatial diff: 4th order laplacian +# temporal diff: O(dt^5) due to RK4 + +import scipy.io +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +# import torch + +def apply_laplacian(mat, dx=1.0): + # dx is inversely proportional to N + """This function applies a discretized Laplacian + in periodic boundary conditions to a matrix + For more information see + https://en.wikipedia.org/wiki/Discrete_Laplace_operator#Implementation_via_operator_discretization + """ + + # the center cell carries weight -5 in the 4th-order stencil + # (-5/2 from each of the two axes) + neigh_mat = -5 * mat.copy() + + # Each direct neighbor on the lattice is counted in + # the discrete difference formula + neighbors = [ + (4 / 3, (-1, 0)), + (4 / 3, (0, -1)), + (4 / 3, (0, 1)), + (4 / 3, (1, 0)), + (-1 / 12, (-2, 0)), + (-1 / 12, (0, -2)), + (-1 / 12, (0, 2)), + (-1 / 12, (2, 0)), + ] + + # shift matrix according to demanded neighbors + # and add to this cell with corresponding weight + for weight, neigh in neighbors: + neigh_mat += weight * np.roll(mat, neigh, (0, 1)) + + return neigh_mat / dx ** 2 + + +def apply_dx(mat, dx=1.0): + ''' central diff for dx''' + + # np.roll, axis=0 -> row + # the total difference + neigh_mat = -0 * mat.copy() + + # Each direct 
neighbor on the lattice is counted in + # the discrete difference formula + neighbors = [ + (1.0 / 12, (2, 0)), + (-8.0 / 12, (1, 0)), + (8.0 / 12, (-1, 0)), + (-1.0 / 12, (-2, 0)) + ] + + # shift matrix according to demanded neighbors + # and add to this cell with corresponding weight + for weight, neigh in neighbors: + neigh_mat += weight * np.roll(mat, neigh, (0, 1)) + + return neigh_mat / dx + + +def apply_dy(mat, dy=1.0): + ''' central diff for dy''' + + # the total difference + neigh_mat = -0 * mat.copy() + + # Each direct neighbor on the lattice is counted in + # the discrete difference formula + neighbors = [ + (1.0 / 12, (0, 2)), + (-8.0 / 12, (0, 1)), + (8.0 / 12, (0, -1)), + (-1.0 / 12, (0, -2)) + ] + + # shift matrix according to demanded neighbors + # and add to this cell with corresponding weight + for weight, neigh in neighbors: + neigh_mat += weight * np.roll(mat, neigh, (0, 1)) + + return neigh_mat / dy + + +def get_temporal_diff(U, V, R, dx): + # u and v in (h, w) + + laplace_u = apply_laplacian(U, dx) + laplace_v = apply_laplacian(V, dx) + + u_x = apply_dx(U, dx) + v_x = apply_dx(V, dx) + + u_y = apply_dy(U, dx) + v_y = apply_dy(V, dx) + + # governing equation + u_t = (1.0 / R) * laplace_u - U * u_x - V * u_y + v_t = (1.0 / R) * laplace_v - U * v_x - V * v_y + + return u_t, v_t + + +def update(U0, V0, R=100.0, dt=0.05, dx=1.0): + u_t, v_t = get_temporal_diff(U0, V0, R, dx) + + U = U0 + dt * u_t + V = V0 + dt * v_t + return U, V + + +def update_rk4(U0, V0, R=100.0, dt=0.05, dx=1.0): + """Update with the Runge-Kutta 4 method + See https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods + """ + ############# Stage 0 ############## + # slope K1: full right-hand side (diffusion and advection) at the initial state + + u_t, v_t = get_temporal_diff(U0, V0, R, dx) + + K1_u = u_t + K1_v = v_t + + ############# Stage 1 ############## + U1 = U0 + K1_u * dt / 2.0 + V1 = V0 + K1_v * dt / 2.0 + + u_t, v_t = get_temporal_diff(U1, V1, R, dx) + + K2_u = u_t + K2_v = v_t + + ############# Stage 2 
############## + U2 = U0 + K2_u * dt / 2.0 + V2 = V0 + K2_v * dt / 2.0 + + u_t, v_t = get_temporal_diff(U2, V2, R, dx) + + K3_u = u_t + K3_v = v_t + + ############# Stage 3 ############## + U3 = U0 + K3_u * dt + V3 = V0 + K3_v * dt + + u_t, v_t = get_temporal_diff(U3, V3, R, dx) + + K4_u = u_t + K4_v = v_t + + # Final solution + U = U0 + dt * (K1_u + 2 * K2_u + 2 * K3_u + K4_u) / 6.0 + V = V0 + dt * (K1_v + 2 * K2_v + 2 * K3_v + K4_v) / 6.0 + + return U, V + + +def postProcess(output, reso, xmin, xmax, ymin, ymax, num, fig_save_dir): + ''' num: Number of time step + ''' + + x = np.linspace(0, reso, reso + 1) + y = np.linspace(0, reso, reso + 1) + x_star, y_star = np.meshgrid(x, y) + x_star, y_star = x_star[:-1, :-1], y_star[:-1, :-1] + + u_pred = output[num, 0, :, :] + v_pred = output[num, 1, :, :] + + fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(6, 3)) + fig.subplots_adjust(hspace=0.3, wspace=0.3) + + cf = ax[0].scatter(x_star, y_star, c=u_pred, alpha=0.95, edgecolors='none', cmap='RdYlBu', + marker='s', s=3, vmin=-1, vmax=1) + ax[0].axis('square') + ax[0].set_xlim([xmin, xmax]) + ax[0].set_ylim([ymin, ymax]) + # cf.cmap.set_under('black') + # cf.cmap.set_over('whitesmoke') + ax[0].set_xticks([]) + ax[0].set_yticks([]) + ax[0].set_title('u-FDM') + fig.colorbar(cf, ax=ax[0], fraction=0.046, pad=0.04) + + cf = ax[1].scatter(x_star, y_star, c=v_pred, alpha=0.95, edgecolors='none', cmap='RdYlBu', + marker='s', s=3, vmin=-1, vmax=1) + ax[1].axis('square') + ax[1].set_xlim([xmin, xmax]) + ax[1].set_ylim([ymin, ymax]) + # cf.cmap.set_under('black') + # cf.cmap.set_over('whitesmoke') + ax[1].set_xticks([]) + ax[1].set_yticks([]) + ax[1].set_title('v-FDM') + fig.colorbar(cf, ax=ax[1], fraction=0.046, pad=0.04) + + # plt.draw() + plt.savefig(fig_save_dir + 'uv_%s.png' % str(num).zfill(4)) + plt.close('all') + + +def rand_gaussian_ic(Num_a, Num_b, Nx, Ny, seed,plot=True): + assert Nx == Ny + x, y = [np.linspace(0, 1, Nx + 1)] * 2 + xx, yy = np.meshgrid(x[:-1], y[:-1]) + 
# print("xx",xx) + # print("yy",yy) + Wx, Wy = [0] * 2 + np.random.seed(seed)#1 + # np.random.seed(2)#2 + # np.random.seed(3)#3 + # np.random.seed(4)#4 + Ax = np.random.normal(0, 1, size=(Num_a, Num_b)) + Bx = np.random.normal(0, 1, size=(Num_a, Num_b)) + Ay = np.random.normal(0, 1, size=(Num_a, Num_b)) + By = np.random.normal(0, 1, size=(Num_a, Num_b)) + cxy = np.random.normal(-1, 1, size=2) + for i in range(Num_a): + for j in range(Num_b): + Wx = Wx + Ax[i, j] * np.sin(2 * np.pi * ((i - Num_a // 2) * xx + (j - Num_b // 2) * yy)) + Bx[ + i, j] * np.cos(2 * np.pi * ((i - Num_a // 2) * xx + (j - Num_b // 2) * yy)) + Wy = Wy + Ay[i, j] * np.sin(2 * np.pi * ((i - Num_a // 2) * xx + (j - Num_b // 2) * yy)) + By[ + i, j] * np.cos(2 * np.pi * ((i - Num_a // 2) * xx + (j - Num_b // 2) * yy)) + # print(cxy) + # print(cxy[0]) + # print(cxy[1]) + Ux = 2 * Wx / Wx.max() + cxy[0] + Uy = 2 * Wy / Wy.max() + cxy[1] + + if plot: + fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(5, 8)) + fig.subplots_adjust(hspace=0.3, wspace=0.1) + # + ax[0].axis('square') + cf = ax[0].contourf(xx, yy, Ux, levels=101, cmap='jet') + ax[0].set_xlim([0, 1]) + ax[0].set_ylim([0, 1]) + ax[0].set_xticks([]) + ax[0].set_yticks([]) + ax[0].set_title(r'$u_0$', ) + fig.colorbar(cf, ax=ax[0], fraction=0.046, pad=0.04) + # + ax[1].axis('square') + cf = ax[1].contourf(xx, yy, Uy, levels=101, cmap='jet') + ax[1].set_xlim([0, 1]) + ax[1].set_ylim([0, 1]) + ax[1].set_xticks([]) + ax[1].set_yticks([]) + ax[1].set_title(r'$v_0$') + fig.colorbar(cf, ax=ax[1], fraction=0.046, pad=0.04) + # + plt.show() + + return Ux, Uy +def createdata(seed): + M, N = 104, 104 + n_simu_steps = 2101 + # dt = 0.00025 # 0.00025 still not converge for FWE, 0.001, 0.002 works for RK4 + dt = 0.001 + dx = 1.0 / M + # R = 120.0 + R = 200 + + print(seed) + U, V = rand_gaussian_ic(Num_a=10, Num_b=10, Nx=104, Ny=104,seed=seed, plot=False) # 2101_1 and 2 + # U, V = rand_gaussian_ic(Num_a=20, Num_b=20, Nx=100, Ny=100, plot=False)#4 + + U, V = 
U / 3.0, V / 3.0 + + # U_record = U.copy()[None, ...] + # V_record = V.copy()[None, ...] + + U_list = [] + V_list = [] + + for step in range(n_simu_steps): + + # U, V = update(U, V, R, dt, dx) # [h,w] + U, V = update_rk4(U, V, R, dt, dx) # [h,w] + + if (step + 1) % 1 == 0: + print(step, '\n') + U_list.append(U[None, ...]) + V_list.append(V[None, ...]) + + U_record = np.concatenate(U_list, axis=0) # [t,h,w] + V_record = np.concatenate(V_list, axis=0) + + UV = np.concatenate((U_record[None, ...], V_record[None, ...]), axis=0) # (c,t,h,w) + UV = np.transpose(UV, [1, 0, 2, 3]) # (t,c,h,w) = (2101,2,104,104) + #np.ascontiguousarray(UV, dtype=np.float64) + + # fig_save_dir = './figures/' + # for i in range(0, 2000, 50): # 1500 + # postProcess(UV, M, 0, M, 0, M, i, fig_save_dir) + + # Save the result (the 'data/' directory must already exist) + data_save_dir = 'data/' + + + scipy.io.savemat(data_save_dir + 'Burgers_2101x2x104x104_[RK4,R=200,dt=0_001,#seed'+str(seed)+'].mat', {'uv': UV}) + + print("seed"+str(seed)+"create over") + +if __name__ == '__main__': + # one trajectory per random seed (seeds 1..15) + for seed in range(1, 16): + createdata(seed) + diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/src/model.py b/MindFlow/applications/data_mechanism_fusion/p2c2net/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5ee9cf4000ae19a8834833f48368c2b9a0531773 --- /dev/null +++ b/MindFlow/applications/data_mechanism_fusion/p2c2net/src/model.py @@ -0,0 +1,292 @@ +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor, context, Parameter, value_and_grad, jit +import mindspore.ops as ops +import numpy as np +from mindflow.cell.neural_operators.dft import dft2, idft2 +from mindflow.cell import FNO2D +from mindspore.amp import DynamicLossScaler, all_finite, auto_mixed_precision +import mindspore as ms + +class RCNN(nn.Cell): + def __init__(self, + deno=1/104.0*4, + time_steps=32, + viscosity=1/200.0, + delta_t=0.001, + in_channels=1, + out_channels=1, + 
resolution=26, + modes=12, + hidden_channels=12, + projection_channels=5, + depths=4, + kernel_size=5, + use_ascend=True, + compute_dtype=mstype.float32 + ): + super(RCNN, self).__init__() + fno_dtype = compute_dtype + if use_ascend: + fno_dtype = mstype.float16 + + + self.u_cor = FNO2D( + in_channels = in_channels, + out_channels = out_channels, + resolutions = resolution, + n_modes = modes, + hidden_channels = hidden_channels, + n_layers = depths, + projection_channels = projection_channels, + data_format = "channels_first", + positional_embedding=False, + fno_compute_dtype = fno_dtype + ) + self.v_cor = FNO2D( + in_channels = in_channels, + out_channels = out_channels, + resolutions = resolution, + n_modes = modes, + hidden_channels = hidden_channels, + n_layers = depths, + projection_channels = projection_channels, + data_format = "channels_first", + positional_embedding=False, + fno_compute_dtype = fno_dtype + ) + + self.filter = Filter( + out_channels=out_channels, + kernel_size=kernel_size, + deno=deno, + compute_dtype=compute_dtype + ) + + self.steps = time_steps + self.vis = viscosity + self.dt = delta_t + + def rk(self, h): + u_prev = h[:, 0:1, ...] + v_prev = h[:, 1:2, ...] 
+ + u_cor = self.u_cor(u_prev) + v_cor = self.v_cor(v_prev) + + + filter_u = self.filter(u_cor) # (du/dx, du/dy, d2u/dx2, d2u/dy2) + filter_v = self.filter(v_cor) # (dv/dx, dv/dy, d2v/dx2, d2v/dy2) + + # vis * (d2u/dx2 + d2u/dy2) - (du/dx * u_prev + du/dy * v_prev) + u_res = (self.vis * (filter_u[:, 2:3, :, :] + filter_u[:, 3:4, :, :]) - ( + u_prev * filter_u[:, 0:1, :, :] + v_prev * filter_u[:, 1:2, :, :])) + + v_res = (self.vis * (filter_v[:, 2:3, :, :] + filter_v[:, 3:4, :, :]) - ( + u_prev * filter_v[:, 0:1, :, :] + v_prev * filter_v[:, 1:2, :, :])) + + return u_res, v_res + + def call_cell(self, h): + u_prev = h[:, 0:1] + v_prev = h[:, 1:2] + k1_u, k1_v = self.rk(h) + u1 = u_prev + k1_u * self.dt / 2.0 + v1 = v_prev + k1_v * self.dt / 2.0 + + k2_u, k2_v = self.rk(ops.concat((u1, v1), axis=1)) + u2 = u_prev + k2_u * self.dt / 2.0 + v2 = v_prev + k2_v * self.dt / 2.0 + + k3_u, k3_v = self.rk(ops.concat((u2, v2), axis=1)) + u3 = u_prev + k3_u * self.dt + v3 = v_prev + k3_v * self.dt + + k4_u, k4_v = self.rk(ops.concat((u3, v3), axis=1)) + u_next = u_prev + (k1_u + 2 * k2_u + 2 * k3_u + k4_u) / 6.0 * self.dt + v_next = v_prev + (k1_v + 2 * k2_v + 2 * k3_v + k4_v) / 6.0 * self.dt + ch = ops.concat((u_next, v_next), axis=1) + return ch + + def construct(self, init_uv): + outputs_uv = [] + internal_uv = init_uv + for step in range(self.steps - 1): + internal_uv = self.call_cell(internal_uv) + if internal_uv.shape[0] > 1: + outputs_uv.append(ops.expand_dims(internal_uv, 0)) + else: # infer + outputs_uv.append(internal_uv) + + outputs_uv = ops.concat(tuple(outputs_uv), axis=0) + return outputs_uv + + +class P2N2Net(nn.Cell): + def __init__(self, model, learning_rate, weight_decay, use_ascend, compute_dtype): + super(P2N2Net, self).__init__() + self.model = model + + self.mse_loss = nn.MSELoss(reduction='mean') + self.optimizer = nn.Adam( + self.model.trainable_params(), + learning_rate=learning_rate, + weight_decay=weight_decay, + + ) + self.grad_fn = value_and_grad( + 
fn = self.forward, + grad_position = None, + weights = self.model.trainable_params(), + has_aux = False + ) + self.use_ascend = use_ascend + + if self.use_ascend: + auto_mixed_precision(self.model, 'O1') + self.loss_scaler = DynamicLossScaler( + scale_value = 2**10, + scale_factor = 2, + scale_window = 1 + ) + + def forward(self, inputs, targets): + output_uv = self.model(inputs) + logits = ops.transpose(output_uv, (1, 0, 2, 3, 4)) + loss = self.mse_loss(logits, targets) * 1e5 + return loss + + # @jit + def construct(self, batch_data): + inputs = batch_data['data'].squeeze() + targets = batch_data['labels'] + loss, grads = self.grad_fn(inputs, targets) + clipped_grads = ops.clip_by_global_norm(grads, clip_norm=1.0) + if self.use_ascend: + loss = self.loss_scaler.unscale(loss) + is_finite = all_finite(grads) + if is_finite: + grads = self.loss_scaler.unscale(grads) + loss = ops.depend(loss, self.optimizer(grads)) + self.loss_scaler.adjust(is_finite) + else: + loss = ops.depend(loss, self.optimizer(grads)) + return loss + + + + + +class Filter(nn.Cell): + def __init__(self, out_channels, kernel_size, deno, compute_dtype): + super(Filter, self).__init__() + self.dx = Conv( + out_channels=out_channels, + kernel_size=kernel_size, + order=1, + deno=deno, + compute_dtype=compute_dtype + ) + self.dxx = Conv( + out_channels=out_channels, + kernel_size=kernel_size, + order=2, + deno=deno, + compute_dtype=compute_dtype + ) + + def padMethod(self, x): + x_pad = ops.concat((x[:, :, :, -2:], x, x[:, :, :, :2]), axis=3) + x_pad = ops.concat((x_pad[:, :, -2:, :], x_pad, x_pad[:, :, :2, :]), axis=2) + return x_pad + + def construct(self, x): + x = self.padMethod(x) + dx = self.dx(x) + dy = ops.transpose(self.dx(ops.transpose(x, (0, 1, 3, 2))), (0, 1, 3, 2)) + dxx = self.dxx(x) + dyy = ops.transpose(self.dxx(ops.transpose(x, (0, 1, 3, 2))), (0, 1, 3, 2)) + res = ops.concat((dx, dy, dxx, dyy), axis=1) + return res + + +class Conv(nn.Cell): + def __init__(self, out_channels, order, 
deno, compute_dtype, kernel_size=5): + super(Conv, self).__init__() + self.deno = deno + self.order = order + self.conv2d = ops.Conv2D( + out_channel=out_channels, + kernel_size=5, + data_format="NCHW" + ) + if self.order == 2: + self.deno = self.deno ** 2 + + self.matrix_3 = ms.Parameter(Tensor(np.random.randn(3, 3), dtype=compute_dtype), + requires_grad=True) + + def get_kernel(self): + matrix = ops.zeros((5, 5)) + matrix[0, 0] = self.matrix_3[0, 0] + matrix[0, 1] = self.matrix_3[0, 1] + matrix[1, 0] = self.matrix_3[1, 0] + matrix[1, 1] = self.matrix_3[1, 1] + matrix[0, 2] = self.matrix_3[0, 2] + matrix[1, 2] = self.matrix_3[1, 2] + matrix[2, 0] = self.matrix_3[2, 0] + matrix[2, 1] = self.matrix_3[2, 1] + + if self.order == 1: + # 1 + matrix[0, 3] = -matrix[0, 1] + matrix[0, 4] = -matrix[0, 0] + matrix[1, 3] = -matrix[1, 1] + matrix[1, 4] = -matrix[1, 0] + # 2 + matrix[3, 0] = -matrix[1, 0] + matrix[4, 0] = -matrix[0, 0] + matrix[3, 1] = -matrix[1, 1] + matrix[4, 1] = -matrix[0, 1] + # 3 + matrix[3, 3] = -matrix[1, 3] + matrix[4, 3] = -matrix[0, 3] + matrix[3, 4] = -matrix[1, 4] + matrix[4, 4] = -matrix[0, 4] + # middle + matrix[3, 2], matrix[4, 2] = -matrix[1, 2], -matrix[0, 2] + matrix[2, 3], matrix[2, 4] = -matrix[2, 1], -matrix[2, 0] + # temp = matrix[2,2] + # matrix[2,2] = -(sum(matrix) - temp) + matrix[2, 2] = 0 + else: + matrix[0, 3] = matrix[0, 1] + matrix[0, 4] = matrix[0, 0] + matrix[1, 3] = matrix[1, 1] + matrix[1, 4] = matrix[1, 0] + # 2 + matrix[3, 0] = matrix[1, 0] + matrix[4, 0] = matrix[0, 0] + matrix[3, 1] = matrix[1, 1] + matrix[4, 1] = matrix[0, 1] + # 3 + matrix[3, 3] = matrix[1, 3] + matrix[4, 3] = matrix[0, 3] + matrix[3, 4] = matrix[1, 4] + matrix[4, 4] = matrix[0, 4] + # middle + matrix[3, 2], matrix[4, 2] = matrix[1, 2], matrix[0, 2] + matrix[2, 3], matrix[2, 4] = matrix[2, 1], matrix[2, 0] + matrix[2, 2] = -( + (matrix[0, 0] + matrix[0, 1] + matrix[1, 0] + matrix[1, 1]) * 4 + ( + matrix[0, 2] + matrix[1, 2] + matrix[2, 0] + matrix[2, 1]) 
* 2) + + return matrix + + def construct(self, x): + # update matrix + m = self.get_kernel() + + weight = ops.expand_dims(ops.expand_dims(m, 0), 0) + central = self.conv2d(x, weight) / self.deno + return central diff --git a/MindFlow/applications/data_mechanism_fusion/p2c2net/train_burgers.py b/MindFlow/applications/data_mechanism_fusion/p2c2net/train_burgers.py new file mode 100644 index 0000000000000000000000000000000000000000..df583570ab2bbe79377dbf6d0a9308f99f0ceda2 --- /dev/null +++ b/MindFlow/applications/data_mechanism_fusion/p2c2net/train_burgers.py @@ -0,0 +1,382 @@ +import argparse +import os +import time +import numpy as np +from mindspore import context, jit, nn, ops, set_seed, Tensor +import mindspore.common.dtype as mstype +from mindspore.dataset import GeneratorDataset +from src.model import P2N2Net, RCNN +from src.data import * +import sys +import signal +import mindspore as ms +from mindflow.utils import log_timer +from mindspore.amp import DynamicLossScaler, all_finite, auto_mixed_precision +import json + +set_seed(12345) +np.random.seed(12345) + +data_dir = "data" +model_dir = "model" +model_name = "P2C2Net_burgers" +config_dir = "config" +result_dir = "result" +loss_dir = "loss" +error_dir = "error" + +# @log_timer +def train_stage( + train_loader, + pretrain_loader, + experiment_dir, + config, + use_ascend, + continue_from, + compute_dtype, +): + + milestone_num = config["milestone_num"] + epochs = config['epochs'] + weight_decay = config["weight_decay"] + save_every = config["save_every"] + gamma = config['gamma'] + lr = config['learning_rate'] + model_config = config["model"] + down = config["down"] + size = config["size"] + resolution = size // down + deno = 1 / size * down + train_win = config["train_window"] + delta_t = config["delta_t"] + + if milestone_num is not None: + milestone = list([(epochs // milestone_num) * (i + 1) + for i in range(milestone_num)]) + lr = list([gamma**i*float(lr) for i in range(milestone_num)]) + learning_rate = 
nn.piecewise_constant_lr(milestone, lr) + else: + learning_rate = config['learning_rate'] + + start_epoch = 1 + + model = RCNN( + **model_config, + delta_t=delta_t, + compute_dtype=compute_dtype, + resolution=resolution, + deno=deno, + time_steps=train_win, + use_ascend=use_ascend + ) + print("Model") + print(model) + + model_param = f"e{epochs}_k{model_config['modes']}_d{model_config['depths']}_l{model_config['hidden_channels']}_p{model_config['projection_channels']}" + result_save_dir = os.path.join(experiment_dir, result_dir) + model_save_dir = os.path.join(result_save_dir, model_dir, model_param, model_name) + loss_save_dir = os.path.join(result_save_dir, loss_dir, model_param) + ensure_directories(result_save_dir, model_save_dir, loss_save_dir) + + + if continue_from is not None: + # Checkpoints are saved as "<model_save_dir>_checkpoints_<epoch>.ckpt"; load with the same naming. + ms.load_checkpoint(model_save_dir + f'_checkpoints_{continue_from}.ckpt', model) + start_epoch = continue_from + print(f"training continues from checkpoint {continue_from}") + + net = P2N2Net( + model, + learning_rate, + weight_decay, + use_ascend, + compute_dtype + ) + + model.set_train(True) + if pretrain_loader: + pretrain_iters = config["pretrain_iters"] + pretrain_window = config["pretrain_window"] + print("pretrain start") + model.steps = pretrain_window + for i in range(pretrain_iters): + print("pretrain ", i) + pretrain_batch_count = 0 + pretrain_epoch_loss = 0 + for batch_data in pretrain_loader.create_dict_iterator(): + # [ic, uv, nx, ny] , [step, uv, nx, ny] + pretrain_batch_loss = net(batch_data) + pretrain_epoch_loss += pretrain_batch_loss + pretrain_batch_count += 1 + print(f'batch {pretrain_batch_count} loss {pretrain_batch_loss}') + print("Pretraining epoch Loss: ", pretrain_epoch_loss / pretrain_batch_count) + + start = time.time() + train_loss_list = [] + print("training start") + model.steps = train_win + for epoch in range(start_epoch, 1 + epochs): + print(f'epoch {epoch} start') + epoch_start = time.time() + epoch_loss = 0 + batch_count = 0 + for batch_data in 
train_loader.create_dict_iterator(): + # [ic, uv, nx, ny] , [step, uv, nx, ny] + batch_loss = net(batch_data) + epoch_loss += batch_loss + batch_count += 1 + print(f'batch {batch_count} loss {batch_loss}') + + epoch_loss = epoch_loss / batch_count + train_loss_list.append(epoch_loss) + epoch_end = time.time() + epoch_time = epoch_end - epoch_start + total_time = epoch_end - start + print("training epoch Loss: ", epoch_loss, "epoch Time: ", epoch_time, "total Time: ", total_time) + + np.savetxt(loss_save_dir + "/train_loss.txt", train_loss_list) + if epoch % save_every == 0 or epoch == epochs: + ms.save_checkpoint( + save_obj=model, + ckpt_file_name = model_save_dir + f"_checkpoints_{epoch}.ckpt" + ) + ms.save_checkpoint( + save_obj=model, + ckpt_file_name = model_save_dir + ".ckpt" + ) + plot_loss(train_loss_list, loss_save_dir) + end = time.time() + total_time = end - start + print("training stage end, total time = ", total_time) + +def test_stage( + test_data, + exp_dir, + config, + compute_dtype, + use_ascend, + normalizer=None, + checkpoint=None +): + + model_config = config["model"] + down = config["down"] + size = config["size"] + resolution = size // down + deno = 1 / size * down + train_win = config["train_window"] + inferstep = config["inferstep"] + delta_t = config["delta_t"] + epochs = config["epochs"] + + model = RCNN( + **model_config, + compute_dtype=compute_dtype, + resolution=resolution, + deno=deno, + time_steps=train_win + ) + model_param = f"e{epochs}_k{model_config['modes']}_d{model_config['depths']}_l{model_config['hidden_channels']}_p{model_config['projection_channels']}" + result_save_dir = os.path.join(exp_dir, result_dir) + model_save_dir = os.path.join(result_save_dir, model_dir, model_param, model_name) + error_save_dir = os.path.join(result_save_dir, error_dir, model_param) + ensure_directories(error_save_dir) + + if checkpoint: + model_save_dir += f"_checkpoints_{checkpoint}" + if use_ascend: + auto_mixed_precision(model, 'O1') + + try: 
+ ms.load_checkpoint(model_save_dir + '.ckpt', model) + print("Successfully loaded model") + except Exception as e: + print(f"Failed to load model: {e}") + exit() + + + model.set_train(False) + evl_error( + test_data[:, 0: inferstep], + model, + inferstep, + delta_t, + resolution, + error_save_dir, + compute_dtype, + normalizer, + ) + + + +def main_function( + experiment_dir, + specs_filename, + continue_from = None, + trainstage = True, + teststage = True, +): + burgers_config = json.load(open(os.path.join(experiment_dir, config_dir, specs_filename))) + use_ascend = context.get_context(attr_key='device_target') == "Ascend" + compute_dtype = mstype.float32 + + num = burgers_config["num_data"] + train_win = burgers_config["train_window"] + timesteps = burgers_config["timesteps"] + batch_size = burgers_config["batch_size"] + down = burgers_config["down"] + nolap = burgers_config["nolap"] + drop = burgers_config["drop"] + normalization = burgers_config["normalization"] + pretrain_iters = burgers_config["pretrain_iters"] + pretrain_window = burgers_config["pretrain_window"] + n_train = burgers_config["num_train"] + n_test = burgers_config["num_test"] + + def signal_handler(sig, frame): + print("Stopping early...") + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + + # [batch, ic, steps, uv, nx, ny] + all_data = get_data( + experiment_dir, + num, + down, + ) + + normalizer = None + if normalization: + normalizer = UnitGaussianNormalizer( + x = all_data, + normalization_dim = (0, 1, 3, 4) + ) + all_data = normalizer.encode(all_data) + + train_data = all_data[: n_train] + test_data = all_data[-n_test:] + + if trainstage: + print("*****************trainingStage****************") + + all_data_set = generate_dataset(train_data[:, :timesteps], train_win=train_win, icnum=n_train, nolap=nolap) + print("all data set shape: ", all_data_set.shape) + # [400, 5, 2, 26, 26] [ic, times, uv, nx, ny] + train_dataset = GeneratorDataset( + source = MyDataset(all_data_set), + 
column_names = ['data', 'labels'], + shuffle=False + ) + train_loader = train_dataset.batch(batch_size, drop) + + if pretrain_iters != 0 and pretrain_window != 0: + pretrain_data_set = generate_dataset(train_data[:, :timesteps], train_win=pretrain_window, icnum=n_train, nolap=nolap) + pretrain_dataset = GeneratorDataset( + source = MyDataset(pretrain_data_set), + column_names = ['data', 'labels'], + shuffle=False + ) + pretrain_loader = pretrain_dataset.batch(batch_size, drop) + else: + pretrain_loader=None + + + train_stage( + train_loader, + pretrain_loader, + experiment_dir, + burgers_config, + use_ascend, + continue_from, + compute_dtype, + ) + if teststage: + print("*****************testingStage****************") + test_stage( + test_data, + experiment_dir, + burgers_config, + compute_dtype, + use_ascend, + normalizer, + checkpoint=continue_from + ) + + +def parse_args(): + """parse input args""" + parser = argparse.ArgumentParser(description="burgers train") + parser.add_argument( + "--experiment", + type=str, + dest="experiment_directory", + required=True, + help="The experiment directory. 
" + ) + parser.add_argument( + "--config_filename", + type=str, + dest="config_filename", + default="burgers.json", + help="The filename of experiment specifications" + ) + parser.add_argument( + "--continue", + type=int, + dest="continue_from", + default=None, + help="A snapshot to continue from.", + ) + parser.add_argument( + "--mode", + type=str, + dest="mode", + default="PYNATIVE", + choices=["GRAPH", "PYNATIVE"], + help="Running in GRAPH_MODE OR PYNATIVE_MODE" + ) + parser.add_argument( + "--device_target", + type=str, + dest="device_target", + default="Ascend", + choices=["GPU", "Ascend"], + help="The target device to run, support 'Ascend', 'GPU'" + ) + parser.add_argument( + "--device_id", + type=int, + dest="device_id", + default=0, + help="ID of the target device" + ) + parser.add_argument( + "--train_stage", + type=bool, + dest="train_stage", + default=True, + help="Whether to run training stage" + ) + parser.add_argument( + "--train_stage", + type=bool, + dest="train_stage", + default=True, + help="Whether to run test stage" + ) + input_args = parser.parse_args() + return input_args + +if __name__ == '__main__': + args = parse_args() + ms.context.set_context( + mode=context.GRAPH_MODE if args.mode.upper().startswith("GRAPH") else context.PYNATIVE_MODE) + ms.set_device(args.device_target, args.device_id) + ms.set_recursion_limit(99999999) + main_function( + experiment_dir = args.experiment_directory, + specs_filename = args.config_filename, + continue_from = args.continue_from, + trainstage = args.train_stage, + teststage = args.test_stage, + )