13 Star 61 Fork 279

Ascend/ModelZoo-TensorFlow

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
贡献代码
同步代码
取消
提示: 由于 Git 不支持空文件夾,创建文件夹后会生成空的 .keep 文件
Loading...
README

基本信息

应用领域(Application Domain): Image Synthesis

版本(Version):1.1

修改时间(Modified) :2022.04.11

大小(Size):6.9M

框架(Framework):TensorFlow_2.4.1

模型格式(Model Format):ckpt

精度(Precision):Mixed

应用级别(Categories):Research

描述(Description):基于wasserstein loss的生成对抗网络

概述

传统GAN网络理论上来说,如果两个分布不相交,则JS散度将不再是连续的,因此将不可微,从而导致梯度为0。WGAN通过使用wasserstein loss解决了这个问题,使得loss函数在任何地方都连续且可微。

默认配置

  • 主要训练超参(单卡):
    • batch_size: 128
    • epochs: 400
    • lr: 0.001

支持特性

特性列表 是否支持
分布式训练
混合精度
数据并行

混合精度训练

昇腾910 AI处理器提供自动混合精度功能,可以针对全网中float32数据类型的算子,按照内置的优化策略,自动将部分float32的算子降低精度到float16,从而在精度损失很小的情况下提升系统性能并减少内存使用。

开启混合精度

  npu_device.global_options().precision_mode='allow_mix_precision'
  npu_device.open().as_default()

训练环境准备

pip3 install requirements.txt

说明:依赖配置文件requirements.txt文件位于模型的根目录

快速上手

数据集准备

  1. 用户需自行下载MNIST训练数据集,应有如下结构
    cifar10/
    ├── mnist.npz
    ├── t10k-images.idx3-ubyte
    ├── t10k-labels.idx3-ubyte
    ├── train-images.idx3-ubyte
    ├── train-labels.idx3-ubyte
    └── ...
    

模型训练

  • 单击“立即下载”,并选择合适的下载方式下载源码包。

  • 开始训练

    1. 启动训练之前,首先要配置程序运行相关环境变量。

      环境变量配置信息参见:

      Ascend 910训练平台环境变量设置

    2. 单卡训练

      2.1 配置train_full_1p.sh脚本中data_path(脚本路径GAN_ID2351_for_TensorFlow2.X/test/train_full_1p.sh),请用户根据实际路径配置,数据集参数如下所示:

       --data_path=/home/MNIST
      

      2.2 1p指令如下:

       bash train_full_1p.sh --data_path=/home/MNIST
      

迁移学习指导

  • 数据集准备。

    1. 获取数据。 请参见“快速上手”中的数据集准备。
  • 模型训练。

    参考“模型训练”中训练步骤。

  • 模型评估。

    参考“模型训练”中验证步骤。

高级参考

脚本和示例代码

convmixer_ID2501_for_TensorFlow2.X/
├── LICENSE
├── modelzoo_level.txt
├── README.md
├── requirements.txt
├── tf_v2_03_WGAN.py
├── test
│   ├── train_full_1p.sh
│   ├── train_performance_1p_static_eval.sh
│   ├── train_performance_1p_dynamic_eval.sh

脚本参数

--data_path       训练数据集路径
--train_epochs         训练epoch设置
--batch_size     训练bs设置

训练过程

  1. 通过“模型训练”中的训练指令启动单卡训练。
  2. 将训练脚本(train_full_1p.sh)中的data_path设置为训练数据集的路径。具体的流程参见“模型训练”的示例。
  3. 模型存储路径为“curpath/output/ASCEND_DEVICE_ID”,包括训练的log文件。
  4. 以多卡训练为例,loss信息在文件curpath/output/{ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log中。

推理/验证过程

 NA

马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ascend/ModelZoo-TensorFlow.git
git@gitee.com:ascend/ModelZoo-TensorFlow.git
ascend
ModelZoo-TensorFlow
ModelZoo-TensorFlow
master

搜索帮助