代码拉取完成,页面将自动刷新
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train process"""
import argparse
import os
import time
import numpy as np
from mindspore import nn, ops, set_seed, jit, context
import mindspore as ms
from mindflow.cell import MultiScaleFCSequential
from mindflow.utils import load_yaml_config
from src import create_training_dataset, create_test_dataset, calculate_l2_error, visual, Poisson2D
set_seed(123456)
np.random.seed(123456)
parser = argparse.ArgumentParser(description="poisson2D on a ring")
parser.add_argument("--mode", type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"],
help="Running in GRAPH_MODE OR PYNATIVE_MODE")
parser.add_argument("--save_graphs", type=bool, default=False, choices=[True, False],
help="Whether to save intermediate compilation graphs")
parser.add_argument("--save_graphs_path", type=str, default="./graphs")
parser.add_argument("--device_target", type=str, default="GPU", choices=["GPU", "Ascend"],
help="The target device to run, support 'Ascend', 'GPU'")
parser.add_argument("--device_id", type=int, default=0, help="ID of the target device")
parser.add_argument("--config_file_path", type=str, default="./cylinder_flow.yaml")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE if args.mode.upper().startswith("GRAPH") else context.PYNATIVE_MODE,
save_graphs=args.save_graphs,
save_graphs_path=args.save_graphs_path,
device_target=args.device_target,
device_id=args.device_id)
print(f"Running in {args.mode.upper()} mode, using device id: {args.device_id}.")
def train():
'''train and evaluate the network'''
# load configurations
config = load_yaml_config('poisson2d_cfg.yaml')
# create training dataset
dataset = create_training_dataset(config)
train_dataset = dataset.batch(batch_size=config["train_batch_size"])
# create test dataset
inputs, label = create_test_dataset(config)
# define models and optimizers
model = MultiScaleFCSequential(in_channels=config["model"]["in_channels"],
out_channels=config["model"]["out_channels"],
layers=config["model"]["layers"],
neurons=config["model"]["neurons"],
residual=config["model"]["residual"],
act=config["model"]["activation"],
num_scales=1)
optimizer = nn.Adam(model.trainable_params(), config["optimizer"]["initial_lr"])
problem = Poisson2D(model)
def forward_fn(pde_data, bc_outer_data, bc_inner_data, bc_inner_normal):
loss = problem.get_loss(pde_data, bc_outer_data, bc_inner_data, bc_inner_normal)
return loss
grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)
@jit
def train_step(pde_data, bc_outer_data, bc_inner_data, bc_inner_normal):
loss, grads = grad_fn(pde_data, bc_outer_data, bc_inner_data, bc_inner_normal)
loss = ops.depend(loss, optimizer(grads))
return loss
epochs = config["train_epochs"]
steps_per_epochs = train_dataset.get_dataset_size()
sink_process = ms.data_sink(train_step, train_dataset, sink_size=1)
for epoch in range(1, epochs + 1):
# train
time_beg = time.time()
model.set_train(True)
for _ in range(steps_per_epochs):
step_train_loss = sink_process()
print(f"epoch: {epoch} train loss: {step_train_loss} epoch time: {(time.time() - time_beg)*1000 :.3f} ms")
model.set_train(False)
if epoch % 100 == 0:
# eval
calculate_l2_error(model, inputs, label, config["train_batch_size"])
visual(model, inputs, label, epochs)
if __name__ == '__main__':
print("pid:", os.getpid())
start_time = time.time()
train()
print("End-to-End total time: {} s".format(time.time() - start_time))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。