diff --git a/MindElec/examples/data_driven/parameterization/dataset/Butterfly_antenna/data_input.npy b/MindElec/examples/data_driven/parameterization/dataset/Butterfly_antenna/data_input.npy new file mode 100755 index 0000000000000000000000000000000000000000..01b3618113f9e70a9b1932ecb1ecb5e28616414f Binary files /dev/null and b/MindElec/examples/data_driven/parameterization/dataset/Butterfly_antenna/data_input.npy differ diff --git a/MindElec/examples/data_driven/parameterization/dataset/Butterfly_antenna/data_label.npy b/MindElec/examples/data_driven/parameterization/dataset/Butterfly_antenna/data_label.npy new file mode 100755 index 0000000000000000000000000000000000000000..c2832f0fab62be6edf9f7599515a5266773a53a1 Binary files /dev/null and b/MindElec/examples/data_driven/parameterization/dataset/Butterfly_antenna/data_label.npy differ diff --git a/MindElec/examples/data_driven/parameterization/dataset/Phone/data_input.npy b/MindElec/examples/data_driven/parameterization/dataset/Phone/data_input.npy new file mode 100755 index 0000000000000000000000000000000000000000..c2949678bed3c614f36b8a63df441b007d49a516 Binary files /dev/null and b/MindElec/examples/data_driven/parameterization/dataset/Phone/data_input.npy differ diff --git a/MindElec/examples/data_driven/parameterization/dataset/Phone/data_label.npy b/MindElec/examples/data_driven/parameterization/dataset/Phone/data_label.npy new file mode 100755 index 0000000000000000000000000000000000000000..571576488614352e8d414956424bd16056d237ca Binary files /dev/null and b/MindElec/examples/data_driven/parameterization/dataset/Phone/data_label.npy differ diff --git a/MindElec/examples/data_driven/parameterization/docs/network_architecture.JPG b/MindElec/examples/data_driven/parameterization/docs/network_architecture.JPG new file mode 100755 index 0000000000000000000000000000000000000000..904666c3d42d821e9843f389cb3433868b6dfaa6 Binary files /dev/null and b/MindElec/examples/data_driven/parameterization/docs/network_architecture.JPG differ diff --git a/MindElec/examples/data_driven/parameterization/eval.py b/MindElec/examples/data_driven/parameterization/eval.py new file mode 100755 index 0000000000000000000000000000000000000000..0d0508b178d3fb896bdfac89076ac41872f9c878 --- /dev/null +++ b/MindElec/examples/data_driven/parameterization/eval.py @@ -0,0 +1,84 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +eval +""" + +import os +import argparse +import numpy as np +import mindspore.nn as nn +from mindspore.common import set_seed +import mindspore.common.dtype as mstype +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from mindelec.solver import Solver + +from src.dataset import create_dataset +from src.maxwell_model import S11Predictor +from src.loss import EvalMetric + +set_seed(123456) +np.random.seed(123456) + +print("pid:", os.getpid()) +parser = argparse.ArgumentParser(description='Parametrization S11 Simulation') +parser.add_argument('--batch_size', type=int, default=8) +parser.add_argument('--input_dim', type=int, default=3) +parser.add_argument('--device_num', type=int, default=1) +parser.add_argument('--device_target', type=str, default="Ascend") +parser.add_argument('--checkpoint_dir', default='./ckpt/', help='checkpoint directory') +parser.add_argument('--input_path', default='./dataset/Butterfly_antenna/data_input.npy') +parser.add_argument('--label_path', default='./dataset/Butterfly_antenna/data_label.npy') +opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, + save_graphs=False, + device_target=opt.device_target, + device_id=opt.device_num) + + +def eval_s11(): + """evaluate s11""" + data, config_data = create_dataset(opt) + + model_net = S11Predictor(opt.input_dim) + model_net.to_float(mstype.float16) + + param_dict = load_checkpoint(os.path.join(opt.checkpoint_dir, 'model.ckpt')) + load_param_into_net(model_net, param_dict) + + eval_error_mrc = EvalMetric(scale_s11=config_data["scale_S11"], + length=data["eval_data_length"], + frequency=data["frequency"], + show_pic_number=4, + file_path='./eval_result') + + solver = Solver(network=model_net, + mode="Data", + optimizer=nn.Adam(model_net.trainable_params(), 0.001), + metrics={'eval_mrc': eval_error_mrc}, + loss_fn=nn.MSELoss()) + + res_eval = solver.model.eval(valid_dataset=data["eval_loader"], dataset_sink_mode=True) + + loss_mse, l2_s11 = res_eval["eval_mrc"]["loss_error"], res_eval["eval_mrc"]["l2_error"] + print(f'Loss_mse: {loss_mse:.10f} ', + f'L2_S11: {l2_s11:.10f}') + + +if __name__ == '__main__': + eval_s11() diff --git a/MindElec/examples/data_driven/parameterization/src/dataset.py b/MindElec/examples/data_driven/parameterization/src/dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..00f41cd06a8d796728fe36a1c882ba6f01e68137 --- /dev/null +++ b/MindElec/examples/data_driven/parameterization/src/dataset.py @@ -0,0 +1,130 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +dataset +""" +import os +import shutil +import numpy as np +from mindelec.data import Dataset, ExistedDataConfig + + +def custom_normalize(data): + """ + get normalize data + """ + print("Custom normalization is called") + ori_shape = data.shape + data = data.reshape(ori_shape[0], -1) + data = np.transpose(data) + mean = np.mean(data, axis=1) + data = data - mean[:, None] + std = np.std(data, axis=1) + std += (np.abs(std) < 0.0000001) + data = data / std[:, None] + data = np.transpose(data) + data = data.reshape(ori_shape) + return data + + +def create_dataset(opt): + """ + load data + """ + data_input_path = opt.input_path + data_label_path = opt.label_path + + data_input = np.load(data_input_path) + data_label = np.load(data_label_path) + + frequency = data_label[0, :, 0] + data_label = data_label[:, :, 1] + + print(data_input.shape) + print(data_label.shape) + print("data load finish") + + data_input = custom_normalize(data_input) + + config_data_prepare = {} + + config_data_prepare["scale_input"] = 0.5 * np.max(np.abs(data_input), axis=0) + config_data_prepare["scale_S11"] = 0.5 * np.max(np.abs(data_label)) + + data_input[:, :] = data_input[:, :] / config_data_prepare["scale_input"] + data_label[:, :] = data_label[:, :] / config_data_prepare["scale_S11"] + + permutation = np.random.permutation(data_input.shape[0]) + data_input = data_input[permutation] + data_label = data_label[permutation] + + length = data_input.shape[0] // 10 + train_input, train_label = data_input[length:], data_label[length:] + eval_input, eval_label = data_input[:length], data_label[:length] + + print(np.shape(train_input)) + print(np.shape(train_label)) + print(np.shape(eval_input)) + print(np.shape(eval_label)) + + if not os.path.exists('./data_prepare'): + os.mkdir('./data_prepare') + else: + shutil.rmtree('./data_prepare') + os.mkdir('./data_prepare') + + train_input = train_input.astype(np.float32) + np.save('./data_prepare/train_input', train_input) + train_label = train_label.astype(np.float32) + np.save('./data_prepare/train_label', train_label) + eval_input = eval_input.astype(np.float32) + np.save('./data_prepare/eval_input', eval_input) + eval_label = eval_label.astype(np.float32) + np.save('./data_prepare/eval_label', eval_label) + + electromagnetic_train = ExistedDataConfig(name="electromagnetic_train", + data_dir=['./data_prepare/train_input.npy', + './data_prepare/train_label.npy'], + columns_list=["inputs", "label"], + data_format="npy") + electromagnetic_eval = ExistedDataConfig(name="electromagnetic_eval", + data_dir=['./data_prepare/eval_input.npy', + './data_prepare/eval_label.npy'], + columns_list=["inputs", "label"], + data_format="npy") + train_batch_size = opt.batch_size + eval_batch_size = len(eval_input) + + train_dataset = Dataset(existed_data_list=[electromagnetic_train]) + train_loader = train_dataset.create_dataset(batch_size=train_batch_size, shuffle=True) + + eval_dataset = Dataset(existed_data_list=[electromagnetic_eval]) + eval_loader = eval_dataset.create_dataset(batch_size=eval_batch_size, shuffle=False) + + data = { + "train_loader": train_loader, + "eval_loader": eval_loader, + + "train_data": train_input, + "train_label": train_label, + "eval_data": eval_input, + "eval_label": eval_label, + + "train_data_length": len(train_label), + "eval_data_length": len(eval_label), + "frequency": frequency, + } + + return data, config_data_prepare diff --git a/MindElec/examples/data_driven/parameterization/src/loss.py b/MindElec/examples/data_driven/parameterization/src/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..97b7bdf8a72f0be771d431d68ead3f4610667d1d --- /dev/null +++ b/MindElec/examples/data_driven/parameterization/src/loss.py @@ -0,0 +1,98 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +loss +""" + +import os +import shutil +import mindspore.nn as nn +import matplotlib.pyplot as plt +import numpy as np +import cv2 + + +class EvalMetric(nn.Metric): + """ + eval metric + """ + + def __init__(self, scale_s11, length, frequency, show_pic_number, file_path): + super(EvalMetric, self).__init__() + self.clear() + self.scale_s11 = scale_s11 + self.length = length + self.frequency = frequency + self.show_pic_number = show_pic_number + self.file_path = file_path + self.show_pic_id = np.random.choice(length, self.show_pic_number, replace=False) + + def clear(self): + """ + clear error + """ + self.error_sum_l2_error = 0 + self.error_sum_loss_error = 0 + self.pic_res = None + + def update(self, *inputs): + """ + update error + """ + if not os.path.exists(self.file_path): + os.mkdir(self.file_path) + else: + shutil.rmtree(self.file_path) + os.mkdir(self.file_path) + + y_pred = self._convert_data(inputs[0]) + y_label = self._convert_data(inputs[1]) + + test_predict, test_label = y_pred, y_label + test_predict[:, :] = test_predict[:, :] * self.scale_s11 + test_label[:, :] = test_label[:, :] * self.scale_s11 + self.pic_res = [] + + for i in range(len(test_label)): + predict_real_temp = test_predict[i] + label_real_temp = test_label[i] + l2_error_temp = np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) / \ + np.sqrt(np.sum(np.square(label_real_temp))) + self.error_sum_l2_error += l2_error_temp + self.error_sum_loss_error += np.mean((label_real_temp - predict_real_temp) ** 2) + + s11_label, s11_predict = label_real_temp, predict_real_temp + plt.figure(dpi=250) + plt.plot(self.frequency, s11_predict, '-', label='AI Model', linewidth=2) + plt.plot(self.frequency, s11_label, '--', label='CST', linewidth=1) + plt.title('s11(dB)') + plt.xlabel('frequency(GHz) l2_s11:' + str(l2_error_temp)[:10]) + plt.ylabel('dB') + plt.legend() + plt.savefig(self.file_path + '/' + str(i) + '_' + str(l2_error_temp)[:10] + '.jpg') + plt.close() + if i in self.show_pic_id: + self.pic_res.append(cv2.imread( + self.file_path + '/' + str(i) + '_' + str(l2_error_temp)[:10] + '.jpg')) + + self.pic_res = np.array(self.pic_res).astype(np.float32) + + def eval(self): + """ + compute final error + """ + return {'l2_error': self.error_sum_l2_error / self.length, + 'loss_error': self.error_sum_loss_error / self.length, + 'pic_res': self.pic_res} diff --git a/MindElec/examples/data_driven/parameterization/src/maxwell_model.py b/MindElec/examples/data_driven/parameterization/src/maxwell_model.py new file mode 100755 index 0000000000000000000000000000000000000000..5b5aecf9c3042d3b48768c9d0b90250b515b5346 --- /dev/null +++ b/MindElec/examples/data_driven/parameterization/src/maxwell_model.py @@ -0,0 +1,47 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +maxwell S11 model +""" + +import mindspore.nn as nn + + +class S11Predictor(nn.Cell): + """ + maxwell S11 model define + """ + def __init__(self, input_dimension): + super(S11Predictor, self).__init__() + self.fc1 = nn.Dense(input_dimension, 128) + self.fc2 = nn.Dense(128, 128) + self.fc3 = nn.Dense(128, 128) + self.fc4 = nn.Dense(128, 128) + self.fc5 = nn.Dense(128, 128) + self.fc6 = nn.Dense(128, 128) + self.fc7 = nn.Dense(128, 1001) + self.relu = nn.ReLU() + + def construct(self, x): + """forward""" + x0 = x + x1 = self.relu(self.fc1(x0)) + x2 = self.relu(self.fc2(x1)) + x3 = self.relu(self.fc3(x1 + x2)) + x4 = self.relu(self.fc4(x1 + x2 + x3)) + x5 = self.relu(self.fc5(x1 + x2 + x3 + x4)) + x6 = self.relu(self.fc6(x1 + x2 + x3 + x4 + x5)) + x = self.fc7(x1 + x2 + x3 + x4 + x5 + x6) + return x diff --git a/MindElec/examples/data_driven/parameterization/train.py b/MindElec/examples/data_driven/parameterization/train.py new file mode 100755 index 0000000000000000000000000000000000000000..ed6dc7a484d439bc1d8a9acf9c3ff373aa8f7852 --- /dev/null +++ b/MindElec/examples/data_driven/parameterization/train.py @@ -0,0 +1,124 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +train +""" + +import os +import argparse +import numpy as np +import mindspore.nn as nn +from mindspore.common import set_seed +import mindspore.common.dtype as mstype +from mindspore import context, save_checkpoint +from mindspore.train.callback import TimeMonitor + +from mindelec.solver import Solver +from mindelec.vision import MonitorTrain, MonitorEval + +from src.dataset import create_dataset +from src.maxwell_model import S11Predictor +from src.loss import EvalMetric + +set_seed(123456) +np.random.seed(123456) + +parser = argparse.ArgumentParser(description='Parametrization S11 Simulation') +parser.add_argument('--epochs', type=int, default=10000) +parser.add_argument('--print_interval', type=int, default=1000) +parser.add_argument('--batch_size', type=int, default=8) +parser.add_argument('--lr', type=float, default=0.0001) +parser.add_argument('--input_dim', type=int, default=3) +parser.add_argument('--device_num', type=int, default=1) +parser.add_argument('--device_target', type=str, default="Ascend") +parser.add_argument('--checkpoint_dir', default='./ckpt/', help='checkpoint directory') +parser.add_argument('--save_graphs_path', default='./graph_result/', help='checkpoint directory') +parser.add_argument('--input_path', default='./dataset/Butterfly_antenna/data_input.npy') +parser.add_argument('--label_path', default='./dataset/Butterfly_antenna/data_label.npy') +opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, + save_graphs=True, + save_graphs_path=opt.save_graphs_path, + device_target=opt.device_target, + device_id=opt.device_num) + + +def get_lr(data): + """get learning rate""" + num_milestones = 10 + if data['train_data_length'] % opt.batch_size == 0: + iter_number = int(data['train_data_length'] / opt.batch_size) + else: + iter_number = int(data['train_data_length'] / opt.batch_size) + 1 + iter_number = opt.epochs * iter_number + milestones = [int(iter_number * i / num_milestones) for i in range(1, num_milestones)] + milestones.append(iter_number) + learning_rates = [opt.lr * 0.5 ** i for i in range(0, num_milestones - 1)] + learning_rates.append(opt.lr * 0.5 ** (num_milestones - 1)) + return milestones, learning_rates + + +def train(): + """train model""" + data, config_data = create_dataset(opt) + + print("scale_input: ", config_data["scale_input"]) + print("scale_s11: ", config_data["scale_S11"]) + + model_net = S11Predictor(opt.input_dim) + model_net.to_float(mstype.float16) + + milestones, learning_rates = get_lr(data) + + optim = nn.Adam(model_net.trainable_params(), + learning_rate=nn.piecewise_constant_lr(milestones, learning_rates)) + + eval_error_mrc = EvalMetric(scale_s11=config_data["scale_S11"], + length=data["eval_data_length"], + frequency=data["frequency"], + show_pic_number=4, + file_path='./eval_res') + + solver = Solver(network=model_net, + mode="Data", + optimizer=optim, + metrics={'eval_mrc': eval_error_mrc}, + loss_fn=nn.MSELoss()) + + monitor_train = MonitorTrain(per_print_times=1, + summary_dir='./summary_dir_train') + + monitor_eval = MonitorEval(summary_dir='./summary_dir_eval', + model=solver, + eval_ds=data["eval_loader"], + eval_interval=opt.print_interval, + draw_flag=True) + + time_monitor = TimeMonitor() + callbacks_train = [monitor_train, time_monitor, monitor_eval] + + solver.model.train(epoch=opt.epochs, + train_dataset=data["train_loader"], + callbacks=callbacks_train, + dataset_sink_mode=True) + + if not os.path.exists(opt.checkpoint_dir): + os.mkdir(opt.checkpoint_dir) + save_checkpoint(model_net, os.path.join(opt.checkpoint_dir, 'model.ckpt')) + + +if __name__ == '__main__': + train()