221 Star 941 Fork 692

GVPMindSpore/mindscience

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train.py 4.57 KB
一键复制 编辑 原始数据 按行查看 历史
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train process"""
import argparse
import os
import time
import numpy as np
from mindspore import nn, ops, set_seed, jit, context
import mindspore as ms
from mindflow.cell import MultiScaleFCSequential
from mindflow.utils import load_yaml_config
from src import create_training_dataset, create_test_dataset, calculate_l2_error, visual, Poisson2D
set_seed(123456)
np.random.seed(123456)
parser = argparse.ArgumentParser(description="poisson2D on a ring")
parser.add_argument("--mode", type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"],
help="Running in GRAPH_MODE OR PYNATIVE_MODE")
parser.add_argument("--save_graphs", type=bool, default=False, choices=[True, False],
help="Whether to save intermediate compilation graphs")
parser.add_argument("--save_graphs_path", type=str, default="./graphs")
parser.add_argument("--device_target", type=str, default="GPU", choices=["GPU", "Ascend"],
help="The target device to run, support 'Ascend', 'GPU'")
parser.add_argument("--device_id", type=int, default=0, help="ID of the target device")
parser.add_argument("--config_file_path", type=str, default="./cylinder_flow.yaml")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE if args.mode.upper().startswith("GRAPH") else context.PYNATIVE_MODE,
save_graphs=args.save_graphs,
save_graphs_path=args.save_graphs_path,
device_target=args.device_target,
device_id=args.device_id)
print(f"Running in {args.mode.upper()} mode, using device id: {args.device_id}.")
def train():
'''train and evaluate the network'''
# load configurations
config = load_yaml_config('poisson2d_cfg.yaml')
# create training dataset
dataset = create_training_dataset(config)
train_dataset = dataset.batch(batch_size=config["train_batch_size"])
# create test dataset
inputs, label = create_test_dataset(config)
# define models and optimizers
model = MultiScaleFCSequential(in_channels=config["model"]["in_channels"],
out_channels=config["model"]["out_channels"],
layers=config["model"]["layers"],
neurons=config["model"]["neurons"],
residual=config["model"]["residual"],
act=config["model"]["activation"],
num_scales=1)
optimizer = nn.Adam(model.trainable_params(), config["optimizer"]["initial_lr"])
problem = Poisson2D(model)
def forward_fn(pde_data, bc_outer_data, bc_inner_data, bc_inner_normal):
loss = problem.get_loss(pde_data, bc_outer_data, bc_inner_data, bc_inner_normal)
return loss
grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)
@jit
def train_step(pde_data, bc_outer_data, bc_inner_data, bc_inner_normal):
loss, grads = grad_fn(pde_data, bc_outer_data, bc_inner_data, bc_inner_normal)
loss = ops.depend(loss, optimizer(grads))
return loss
epochs = config["train_epochs"]
steps_per_epochs = train_dataset.get_dataset_size()
sink_process = ms.data_sink(train_step, train_dataset, sink_size=1)
for epoch in range(1, epochs + 1):
# train
time_beg = time.time()
model.set_train(True)
for _ in range(steps_per_epochs):
step_train_loss = sink_process()
print(f"epoch: {epoch} train loss: {step_train_loss} epoch time: {(time.time() - time_beg)*1000 :.3f} ms")
model.set_train(False)
if epoch % 100 == 0:
# eval
calculate_l2_error(model, inputs, label, config["train_batch_size"])
visual(model, inputs, label, epochs)
if __name__ == '__main__':
print("pid:", os.getpid())
start_time = time.time()
train()
print("End-to-End total time: {} s".format(time.time() - start_time))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mindspore/mindscience.git
git@gitee.com:mindspore/mindscience.git
mindspore
mindscience
mindscience
r0.6

搜索帮助