208 Star 856 Fork 632

GVPMindSpore / mindscience

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
README_CN.md 7.82 KB
一键复制 编辑 原始数据 按行查看 历史

ENGLISH | 简体中文

目录

ENSO 描述

厄尔尼诺与南方涛动(ENSO)现象对区域生态系统影响较大,因此,准确的ENSO的预测带来了巨大的区域效益。 然而,对ENSO超过一年的预测仍然存在问题。最近,卷积神经网络(CNN)已被证明是预测ENSO的有效工具。

在这个模型中,我们实现了CNN的训练和评估过程,用于用气象数据预测ENSO。

论文: Ham, Y.-G., J.-H. Kim, and J.-J. Luo, 2019: Deep learning for multi-year ENSO forecasts. Nature, 573, 568–572.

数据集

用于训练的数据集和预训练checkpoints文件将会在首次启动时自动下载。

  • 数据格式: .npy文件
    • 注: 数据会在process.py中处理
  • 数据集在./data目录下,目录结构如下:
├── data
│   ├── htmp_data
│   ├── train_data
│   │   ├── ACCESS-CM2
│   │   ├── CCSM4
│   │   ├── CESM1-CAM5
│   │   ├── ...
│   │   └── obs
│   └── var_data

您如果需要手动下载数据集或checkpoints文件, 请访问此链接

环境要求

快速开始

通过官网安装好MindSpore和上面需要的数据集后,就可以开始训练和验证如下:

  • 在 Ascend 或 GPU 上运行

默认:

python train.py

完整命令:

python train.py \
    --save_ckpt true \
    --load_ckpt false \
    --save_ckpt_path ./checkpoints \
    --load_ckpt_path ./checkpoints/exp2_aftertrain/enso_float16.ckpt \
    --save_data true\
    --load_data_path ./data \
    --save_data_path ./data \
    --save_figure true \
    --figures_path ./figures \
    --log_path ./logs \
    --print_interval 10 \
    --lr 0.01 \
    --epochs 20 \
    --batch_size 400 \
    --skip_aftertrain false \
    --epochs_after 5 \
    --batch_size_after 30 \
    --lr_after 1e-6 \
    --download_data enso \
    --force_download false \
    --amp_level O3 \
    --device_id 0 \
    --mode 0

脚本说明

脚本和示例代码

├── enso
│   ├── checkpoints                       # checkpoint文件
│   ├── data                              # 数据文件
│   │   ├── htmp_data                     # 验证结果的保存路径
│   │   ├── var_data                      # 验证数据集
│   │   └── train_data                    # 训练数据集
│   ├── figures                           # 结果图片
│   ├── logs                              # 日志文件
│   ├── src                               # 源代码
│   │   ├── network.py                    # 网络架构
│   │   ├── plot.py                       # 绘制结果
│   │   └── process.py                    # 数据处理
│   ├── config.yaml                       # 超参数配置
│   ├── README.md                         # 英文模型说明
│   ├── README_CN.md                      # 中文模型说明
│   ├── train.py                          # python训练脚本
│   └── eval.py                           # python评估脚本

脚本参数

train.py中的重要参数如下:

参数名 描述 默认值
save_ckpt 是否保存checkpoint true
load_ckpt 是否加载checkpoint false
save_ckpt_path checkpoint保存路径 ./checkpoints
load_ckpt_path checkpoint加载路径 ./checkpoints/exp2_aftertrain/enso_float16.ckpt
save_data 是否保存数据 true
load_data_path 加载数据的路径 ./data
save_data_path 保存数据的路径 ./data
save_figure 是否保存和绘制图片 true
figures_path 图片保存路径 ./figures
log_path 日志保存路径 ./logs
print_interval 时间与loss打印间隔 10
lr 学习率 0.01
epochs 时期(迭代次数) 20
batch_size 数据集的大小 400
skip_aftertrain 是否跳过训练后的流程 false
epochs_after 训练后流程的时期(迭代次数) 5
batch_size_after 训练后流程的数据集大小 30
lr_after 训练后流程的学习率 1e-6
download_data 模型所需数据集与(或)checkpoints enso
force_download 是否强制下载数据 false
amp_level MindSpore自动混合精度等级 O3
device_id 需要设置的设备号 None
mode MindSpore静态图模式(0)或动态图模式(1) 0

训练流程

# python train.py
...
epoch: 1 step: 1, loss is 0.9130635857582092
epoch: 1 step: 2, loss is 1.0354164838790894
epoch: 1 step: 3, loss is 0.8914494514465332
epoch: 1 step: 4, loss is 0.9377754330635071
epoch: 1 step: 5, loss is 1.0472232103347778
epoch: 1 step: 6, loss is 1.0421113967895508
epoch: 1 step: 7, loss is 1.100639820098877
epoch: 1 step: 8, loss is 0.9849204421043396
...
  • 训练结束后,您仍然可以通过保存在log_path下面的日志文件回顾训练过程,默认为./logs目录中。

  • 模型checkpoint将保存在 save_ckpt_path中,默认为./checkpoints 目录中。

推理流程

在运行下面的命令之前,请检查使用的config.yaml 中的checkpoint加载路径load_ckpt_path 进行推理。

python eval.py

您可以通过日志文件log_path查看过程与结果,默认位于./logs 。 结果图片存放于figures_path中,默认位于./figures

1
https://gitee.com/mindspore/mindscience.git
git@gitee.com:mindspore/mindscience.git
mindspore
mindscience
mindscience
r0.5

搜索帮助