diff --git a/MindElec/mindelec/vision/body.py b/MindElec/mindelec/vision/body.py new file mode 100644 index 0000000000000000000000000000000000000000..380f76e53b98bbd95a2e3b6ea61a915d4b6346d0 --- /dev/null +++ b/MindElec/mindelec/vision/body.py @@ -0,0 +1,101 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Visualization of the results 3D VTK form""" + +import os +from pyevtk.hl import gridToVTK +import numpy as np + + +def vtk_structure(grid_tensor, eh_tensor, path_res): + r""" + Generates 3D vtk file for visualizaiton. + + Args: + grid_tensor (np.array): grid data (shape: (dim_t, dim_x, dim_y, dim_z, 4)). + eh_tensor (np.array): electric and magnetic data (np.array, shape: (dim_t, dim_x, dim_y, dim_z, 6)). + path_res (str): save path for the output vtk file. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> from mindelec.vision import vtk_structure + >>> grid_tensor = np.random.rand(20, 10, 10, 10, 4).astype(np.float32) + >>> eh_tensor = np.random.rand(20, 10, 10, 10, 6).astype(np.float32) + >>> path_res = './result_vtk' + >>> vtk_structure(grid_tensor, eh_tensor, path_res) + """ + if not isinstance(grid_tensor, np.ndarray): + raise TypeError("The type of grid_tensor should be numpy array, but get {}".format(type(grid_tensor))) + + if not isinstance(eh_tensor, np.ndarray): + raise TypeError("The type of eh_tensor should be numpy array, but get {}".format(type(eh_tensor))) + + if not isinstance(path_res, str): + raise TypeError("The type of path_res should be str, but get {}".format(type(path_res))) + if not os.path.exists(path_res): + os.makedirs(path_res) + + input_grid = grid_tensor + output_grid = eh_tensor + + shape_grid = input_grid.shape + shape_eh = output_grid.shape + + if len(shape_grid) != 5 or shape_grid[-1] != 4: + raise ValueError("grid_tensor shape should be (dim_t, dim_x, dim_y, dim_z, 4), but get {}" + .format(shape_grid)) + + if len(shape_eh) != 5 or shape_eh[-1] != 6: + raise ValueError("eh_tensor shape should be (dim_t, dim_x, dim_y, dim_z, 6), but get {}" + .format(shape_eh)) + + if shape_grid[:4] != shape_eh[:4]: + raise ValueError("grid_tensor and eh_tensor should have the same dimension except the last axis, " + "but get grid_tensor shape {} and eh_tensor shape{}".format(shape_grid, shape_eh)) + + (dim_t, dim_x, dim_y, dim_z, d) = input_grid.shape + input_grid = np.reshape(input_grid, (dim_t * dim_x * dim_y * dim_z, d)) + x_min, x_max = np.min(input_grid[:, 0]), np.max(input_grid[:, 0]) + y_min, y_max = np.min(input_grid[:, 1]), np.max(input_grid[:, 1]) + z_min, z_max = np.min(input_grid[:, 2]), np.max(input_grid[:, 2]) + + x_all = np.linspace(x_min, x_max, dim_x, endpoint=True, dtype='float64') + y_all = np.linspace(y_min, y_max, dim_y, endpoint=True, dtype='float64') + z_all = np.linspace(z_min, z_max, dim_z, endpoint=True, dtype='float64') + + x = np.zeros((dim_x, dim_y, dim_z)) + y = np.zeros((dim_x, dim_y, dim_z)) + z = np.zeros((dim_x, dim_y, dim_z)) + + for i in range(dim_x): + for j in range(dim_y): + for k in range(dim_z): + x[i, j, k] = x_all[i] + y[i, j, k] = y_all[j] + z[i, j, k] = z_all[k] + + for t in range(dim_t): + print(t) + output_grid_show = output_grid[t] + ex, ey, ez = output_grid_show[:, :, :, 0], output_grid_show[:, :, :, 1], output_grid_show[:, :, :, 2] + hx, hy, hz = output_grid_show[:, :, :, 3], output_grid_show[:, :, :, 4], output_grid_show[:, :, :, 5] + ex, ey, ez = ex.astype(np.float64), ey.astype(np.float64), ez.astype(np.float64) + hx, hy, hz = hx.astype(np.float64), hy.astype(np.float64), hz.astype(np.float64) + gridToVTK(os.path.join(path_res, 'eh_t' + str(t)), + x, y, z, + pointData={"Ex": ex, "Ey": ey, "Ez": ez, "Hx": hx, "Hy": hy, "Hz": hz}) diff --git a/MindElec/mindelec/vision/mindinsight_vision.py b/MindElec/mindelec/vision/mindinsight_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7a447efdfa53b7763069fef6926855de9aacb5 --- /dev/null +++ b/MindElec/mindelec/vision/mindinsight_vision.py @@ -0,0 +1,221 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Callback functions for model.train and model.eval""" + +import os +import numpy as np +from mindspore.dataset.engine.datasets import BatchDataset as ds +from mindspore.common.tensor import Tensor +from mindspore.train.callback import Callback +from mindspore.train.summary import SummaryRecord +from ..solver import Solver + + +class MonitorTrain(Callback): + r""" + Loss monitor for train. + + Args: + per_print_times (int): print loss interval. + summary_dir (str): summary save path. + + Returns: + Callback monitor. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindelec.vision import MonitorTrain + >>> per_print_times = 1 + >>> summary_dir = './summary_train' + >>> MonitorTrain(per_print_times, summary_dir) + """ + + def __init__(self, per_print_times=1, summary_dir='./summary_train'): + super(MonitorTrain, self).__init__() + if not isinstance(per_print_times, int): + raise TypeError("per_print_times must be int, but get {}".format(type(per_print_times))) + if per_print_times <= 0: + raise ValueError("per_print_times must be > 0.") + + if not isinstance(summary_dir, str): + raise TypeError("summary_dir must be str, but get {}".format(type(summary_dir))) + if not os.path.exists(summary_dir): + os.makedirs(summary_dir) + + self._per_print_times = per_print_times + self._summary_dir = summary_dir + self._step_counter = 0 + self.final_loss = 0 + + def __enter__(self): + self.summary_record = SummaryRecord(self._summary_dir) + return self + + def __exit__(self, *exc_args): + self.summary_record.close() + + def step_end(self, run_context): + """ + Evaluate the model at the end of epoch. + + Args: + run_context (RunContext): Context of the train running. + """ + self._step_counter += 1 + params = run_context.original_args() + + loss = params.net_outputs + + if isinstance(loss, (tuple, list)): + if isinstance(loss[0].asnumpy(), np.ndarray) and isinstance(loss[0], Tensor): + loss = loss[0] + + if isinstance(loss.asnumpy(), np.ndarray) and isinstance(loss, Tensor): + loss = np.mean(loss.asnumpy()) + + cur_step = (params.cur_step_num - 1) % params.batch_num + 1 + + if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): + raise ValueError("epoch: {} step: {}. Invalid loss, training end.".format( + params.cur_epoch_num, cur_step)) + + if self._per_print_times != 0 and params.cur_step_num % self._per_print_times == 0: + print("epoch: %s step: %s, loss is %s" % (params.cur_epoch_num, cur_step, loss), flush=True) + self.summary_record.add_value('scalar', 'train_loss', Tensor(loss)) + self.summary_record.record(self._step_counter) + self.final_loss = loss + + +class MonitorEval(Callback): + r""" + LossMonitor for eval. + + Args: + summary_dir (str): summary save path. + model (Solver): Model object for eval. + eval_ds (Dataset): eval dataset. + eval_interval (int): eval interval. + draw_flag (bool): specifies if save summary_record. + + Returns: + Callback monitor. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import mindspore.nn as nn + >>> from mindelec.solver import Solver + >>> from mindelec.vision import MonitorEval + >>> class S11Predictor(nn.Cell): + ... def __init__(self, input_dimension): + ... super(S11Predictor, self).__init__() + ... self.fc1 = nn.Dense(input_dimension, 128) + ... self.fc2 = nn.Dense(128, 128) + ... self.fc3 = nn.Dense(128, 1001) + ... self.relu = nn.ReLU() + ... + ... def construct(self, x): + ... x0 = x + ... x1 = self.relu(self.fc1(x0)) + ... x2 = self.relu(self.fc2(x1)) + ... x = self.fc3(x1 + x2) + ... return x + >>> model_net = S11Predictor(3) + >>> model = Solver(network=model_net, mode="Data", optimizer=nn.Adam(0.001), loss_fn=nn.MSELoss()) + >>> # For details about how to build the dataset, please refer to the tutorial + >>> # document on the official website. + >>> eval_ds = Dataset() + >>> summary_dir = './summary_eval_path' + >>> eval_interval = 10 + >>> draw_flag = True + >>> MonitorEval(summary_dir, model, eval_ds, eval_interval, draw_flag) + """ + + def __init__(self, + summary_dir='./summary_eval', + model=None, + eval_ds=None, + eval_interval=10, + draw_flag=True): + super(MonitorEval, self).__init__() + if not isinstance(summary_dir, str): + raise ValueError("summary_dir must be str, but get {}".format(type(summary_dir))) + + if not isinstance(model, Solver): + raise ValueError("model must be mindelec solver, but get {}".format(type(model))) + + if not isinstance(eval_ds, ds): + raise ValueError("eval dataset must be mindelec dataset, but get {}".format(type(eval_ds))) + + if not isinstance(eval_interval, int): + raise TypeError("eval_interval must be int, but get {}".format(type(eval_interval))) + if eval_interval <= 0: + raise ValueError("eval_interval must be > 0.") + + if not isinstance(draw_flag, bool): + raise ValueError("draw_flag must be bool, but get {}".format(type(draw_flag))) + + self._summary_dir = summary_dir + self._model = model + self._eval_ds = eval_ds + self._eval_interval = eval_interval + self._draw_flag = draw_flag + + self._eval_count = 0 + self.temp = None + self.loss_final = 0.0 + self.l2_s11_final = 0.0 + + def __enter__(self): + self.summary_record = SummaryRecord(self._summary_dir) + return self + + def __exit__(self, *exc_args): + self.summary_record.close() + + def epoch(self, run_context): + """ + Evaluate the model at the end of epoch. + + Args: + run_context (RunContext): Context of the train running. + """ + self.temp = run_context + self._eval_count += 1 + cb_param = run_context.original_args() + cur_epoch = cb_param.cur_epoch_num + if cur_epoch % self._eval_interval == 0: + res_eval = self._model.model.eval(valid_dataset=self._eval_ds, dataset_sink_mode=True) + loss_eval_print, l2_s11_print = res_eval['eval_mrc']['loss_error'], res_eval['eval_mrc']['l2_error'] + + self.loss_final = loss_eval_print + self.l2_s11_final = l2_s11_print + print('Eval current epoch:', cur_epoch, ' loss:', loss_eval_print, ' l2_s11:', l2_s11_print) + + self.summary_record.add_value('scalar', 'eval_loss', Tensor(loss_eval_print)) + self.summary_record.record(self._eval_count * self._eval_interval) + + self.summary_record.add_value('scalar', 'l2_s11', Tensor(l2_s11_print)) + self.summary_record.record(self._eval_count * self._eval_interval) + + if self._draw_flag: + pic_res = res_eval['eval_mrc']['pic_res'] + for i in range(len(pic_res)): + self.summary_record.add_value('image', 'l2_s11_image_' + str(i), + Tensor(np.expand_dims(pic_res[i], 0).transpose((0, 3, 1, 2)))) + self.summary_record.record(self._eval_count * self._eval_interval) diff --git a/MindElec/mindelec/vision/plane.py b/MindElec/mindelec/vision/plane.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd83164d322c47bf5e4f56ee1ebd11c8a2d0e64 --- /dev/null +++ b/MindElec/mindelec/vision/plane.py @@ -0,0 +1,138 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Visualization of the results in 2D image form""" + +import os +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + +matplotlib.use('Agg') + + +def plot_s11(s11_tensor, path_image_save, legend, dpi=300): + r""" + Draw s11-frequency curve and save it in path_image_save. + + Args: + s11_tensor (np.array): s11 data (shape: (dim_frequency, 2)). + path_image_save (str): s11-frequency curve saved path. + legend (str): the legend of s11, plotting parameters. + dpi (int): plotting parameters. Default: 300. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> from mindelec.vision import plot_s11 + >>> s11 = np.random.rand(1001, 2).astype(np.float32) + >>> s11[:, 0] = np.linspace(0, 4 * 10 ** 9, 1001) + >>> s11 = s11.astype(np.float32) + >>> s11_tensor = s11 + >>> path_image_save = './result_s11' + >>> legend = 's11' + >>> dpi = 300 + >>> plot_s11(s11_tensor, path_image_save, legend, dpi) + """ + if not isinstance(s11_tensor, np.ndarray): + raise TypeError("The type of s11_tensor should be numpy array, but get {}".format(type(s11_tensor))) + + if not isinstance(path_image_save, str): + raise TypeError("The type of path_image_save should be str, but get {}".format(type(path_image_save))) + if not os.path.exists(path_image_save): + os.makedirs(path_image_save) + + if not isinstance(legend, str): + raise TypeError("The type of legend should be str, but get {}".format(type(legend))) + + if not isinstance(dpi, int): + raise TypeError("The type of dpi must be int, but get {}".format(type(dpi))) + if dpi <= 0: + raise ValueError("dpi must be > 0.") + + shape_s11_all = s11_tensor.shape + if len(shape_s11_all) != 2 or shape_s11_all[-1] != 2: + raise ValueError("s11_tensor shape should be (dim_frequency, 2), but get {}".format(shape_s11_all)) + + s11_temp, frequency = s11_tensor[:, 0], s11_tensor[:, 1] + plt.figure(dpi=dpi, figsize=(8, 4)) + plt.plot(frequency, s11_temp, '-', label=legend, linewidth=2) + plt.title('s11(dB)') + plt.xlabel('frequency(Hz)') + plt.ylabel('dB') + plt.legend() + plt.savefig(os.path.join(path_image_save, 's11.jpg')) + plt.close() + + +def plot_eh(simu_res_tensor, path_image_save, z_index, dpi=300): + r""" + Draw electric and magnetic field values of every timestep for 2D slices, and save them in path_image_save + + Args: + simu_res_tensor (np.array): simulation result (shape (dim_t, dim_x, dim_y, dim_z, 6)). + path_image_save (str): images saved path. + z_index (int): show 2D image for z=z_index. + dpi (int): plotting parameters. Default: 300. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> from mindelec.vision import plot_eh + >>> simu_res_tensor = np.random.rand(20, 10, 10, 10, 6).astype(np.float32) + >>> path_image_save = './result_eh' + >>> z_index = 5 + >>> dpi = 300 + >>> plot_eh(simu_res_tensor, path_image_save, z_index, dpi) + """ + if not isinstance(simu_res_tensor, np.ndarray): + raise TypeError("The type of simu_res_tensor should be numpy array, but get {}".format(type(simu_res_tensor))) + + if not isinstance(path_image_save, str): + raise TypeError("The type of path_image_save should be str, but get {}".format(type(path_image_save))) + if not os.path.exists(path_image_save): + os.makedirs(path_image_save) + + if not isinstance(z_index, int): + raise TypeError("The type of z_index must be int, but get {}".format(type(z_index))) + if z_index <= 0: + raise ValueError("z_index must be > 0.") + + if not isinstance(dpi, int): + raise TypeError("The type of dpi must be int, but get {}".format(type(dpi))) + if dpi <= 0: + raise ValueError("dpi must be > 0.") + + shape_simulation_res = simu_res_tensor.shape + if len(shape_simulation_res) != 5 or shape_simulation_res[-1] != 6: + raise ValueError("path_simu_res shape should be (dim_t, dim_x, dim_y, dim_z, 6), but get {}" + .format(shape_simulation_res)) + + plot_order = ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz'] + + for i in range(6): + current = simu_res_tensor[:, :, :, z_index, i] + min_val, max_val = np.min(current), np.max(current) + timesteps = len(current) + for t in range(timesteps): + current_2d = current[t] + plt.figure(dpi=dpi) + plt.imshow(current_2d, vmin=min_val, vmax=max_val) + plt.colorbar() + plt.savefig(os.path.join(path_image_save, plot_order[i] + '_' + str(t) + '.jpg')) + plt.close() diff --git a/MindElec/mindelec/vision/print_scatter.py b/MindElec/mindelec/vision/print_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..dfce1a46dd909b40b8ae13dc09793ba4f31f5c23 --- /dev/null +++ b/MindElec/mindelec/vision/print_scatter.py @@ -0,0 +1,118 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""util functions for tests""" + +import os +import numpy as np +import matplotlib.pyplot as plt + + +def print_graph_1d(name, x, path, clear=True): + r""" + Draw 1d scatter image + + Args: + name (str): name of the graph. + x (np.array): data to draw (shape (dim_print,)). + path (str): save path of the graph. + clear (bool): specifies whether clear the current axes. Default: True. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> from mindelec.vision import print_graph_1d + >>> name = "output.jpg" + >>> x = np.ones(10) + >>> path = "./graph_1d" + >>> clear = True + >>> print_graph_1d(name, x, path, clear) + """ + if not isinstance(name, str): + raise TypeError("The type of name should be str, but get {}".format(type(name))) + + if not isinstance(x, np.ndarray): + raise TypeError("The type of x should be numpy array, but get {}".format(type(x))) + shape_x = x.shape + if len(shape_x) != 1: + raise ValueError("x shape should be (dim_print,), but get {}".format(shape_x)) + + if not isinstance(path, str): + raise TypeError("The type of path should be str, but get {}".format(type(path))) + if not os.path.exists(path): + os.makedirs(path) + + if not isinstance(clear, bool): + raise TypeError("The type of clear should be bool, but get {}".format(type(clear))) + + if clear: + plt.cla() + y = np.zeros(x.shape) + plt.scatter(x, y, alpha=0.8, s=0.8) + plt.savefig(os.path.join(path, name), dpi=600) + + +def print_graph_2d(name, x, y, path, clear=True): + r""" + Draw 2d scatter image + + Args: + name (str): name of the graph. + x (np.array): data x to draw (shape (dim_print,)). + y (np.array): data y to draw (shape (dim_print,)). + path (str): save path of the graph. + clear (bool): specifies whether clear the current axes. Default: True. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> from mindelec.vision import print_graph_2d + >>> name = "output.jpg" + >>> x = np.ones(10) + >>> y = np.ones(10) + >>> path = "./graph_2d" + >>> clear = True + >>> print_graph_2d(name, x, y, path, clear) + """ + if not isinstance(name, str): + raise TypeError("The type of name should be str, but get {}".format(type(name))) + + if not isinstance(x, np.ndarray): + raise TypeError("The type of x should be numpy array, but get {}".format(type(x))) + shape_x = x.shape + if len(shape_x) != 1: + raise ValueError("x shape should be (dim_print,), but get {}".format(shape_x)) + + if not isinstance(y, np.ndarray): + raise TypeError("The type of y should be numpy array, but get {}".format(type(y))) + shape_y = y.shape + if len(shape_y) != 1: + raise ValueError("y shape should be (dim_print,), but get {}".format(shape_y)) + + if not isinstance(path, str): + raise TypeError("The type of path should be str, but get {}".format(type(path))) + if not os.path.exists(path): + os.makedirs(path) + + if not isinstance(clear, bool): + raise TypeError("The type of clear should be bool, but get {}".format(type(clear))) + + if clear: + plt.cla() + plt.scatter(x, y, alpha=1.0, s=0.8) + plt.savefig(os.path.join(path, name), dpi=600) diff --git a/MindElec/mindelec/vision/video.py b/MindElec/mindelec/vision/video.py new file mode 100644 index 0000000000000000000000000000000000000000..d739d6f27667064218614d96584692d9fd82680a --- /dev/null +++ b/MindElec/mindelec/vision/video.py @@ -0,0 +1,87 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Visualization of the results in video form""" + +import os +from importlib import import_module +from PIL import Image + + +def image_to_video(path_image, path_video, video_name, fps): + r""" + Create video from existing images. + + Args: + path_image (str): image path, all images are jpg. + image names in path_image should be like: + 00.jpg, 01.jpg, 02.jpg, ... 09.jpg, 10.jpg, 11.jpg, 12.jpg ... + path_video (str): video path, video saved path. + video_name (str): video name(.avi file) + fps (int): Specifies how many pictures per second in video. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> from mindelec.vision import plot_eh, image_to_video + >>> path_image = './images' + >>> eh = np.random.rand(5, 10, 10, 10, 6).astype(np.float32) + >>> plot_eh(eh, path_image, 5, 300) + >>> path_video = './result_video' + >>> video_name = 'video.avi' + >>> fps = 10 + >>> image_to_video(path_image, path_video, video_name, fps) + """ + if not isinstance(path_image, str): + raise TypeError("The type of path_image should be str, but get {}".format(type(path_image))) + if not os.path.exists(path_image): + raise Exception("No folder of images found in path_image, please check the path") + + if not isinstance(path_video, str): + raise TypeError("The type of path_video should be str, but get {}".format(type(path_video))) + if not os.path.exists(path_video): + os.makedirs(path_video) + + if not isinstance(video_name, str): + raise TypeError("The type of video_name should be str, but get {}".format(type(video_name))) + if '.avi' not in video_name or len(video_name) <= 4: + raise Exception("video_name should be .avi file, like result.avi, please check the video_name") + if video_name[-4:] != '.avi': + raise Exception("video_name should be .avi file, like result.avi, please check the video_name") + + if not isinstance(fps, int): + raise TypeError("The type of fps must be int, but get {}".format(type(fps))) + if fps <= 0: + raise ValueError("fps must be > 0.") + + cv2 = import_module("cv2") + fourcc = cv2.VideoWriter_fourcc(*"MJPG") + + images = os.listdir(path_image) + images.sort() + image = Image.open(os.path.join(path_image, images[0])) + vw = cv2.VideoWriter(os.path.join(path_video, video_name), fourcc, fps, image.size) + + for i in range(len(images)): + print(float(i) / len(images)) + jpgfile = os.path.join(path_image, images[i]) + try: + new_frame = cv2.imread(jpgfile) + vw.write(new_frame) + except IOError as exc: + print(jpgfile, exc) + vw.release() + print('Video save success!') diff --git a/README.md b/README.md index 973296add64523a2f6be0e6f726d6e5e5b042d5e..fdda4c56b6057da85affd8760e20a01e9dd6f2f6 100644 --- a/README.md +++ b/README.md @@ -1 +1,10 @@ -MindSpore for science. +# MindScience + +## 概述 + +MindScience是基于MindSpore融合架构打造的科学计算行业套件,包含了业界领先的数据集、基础模型、预制高精度模型和前后处理工具,加速了科学行业应用开发。目前已推出面向电子信息行业的MindElec套件和面向生命科学行业的MindSPONGE套件,分别实现了电磁仿真性能提升10倍和生物制药化合物模拟效率提升50%。 + +## 架构图 +
+