应用领域(Application Domain): Image Synthesis
版本(Version):1.1
修改时间(Modified) :2022.04.11
大小(Size):6.9M
框架(Framework):TensorFlow_2.4.1
模型格式(Model Format):ckpt
精度(Precision):Mixed
应用级别(Categories):Research
描述(Description):基于wasserstein loss的生成对抗网络
传统GAN网络理论上来说,如果两个分布不相交,则JS散度将不再是连续的,因此将不可微,从而导致梯度为0。WGAN通过使用wasserstein loss解决了这个问题,使得loss函数在任何地方都连续且可微。
参考论文:
参考实现:
适配昇腾 AI 处理器的实现:
通过Git获取对应commit_id的代码方法如下:
git clone {repository_url} # 克隆仓库的代码
cd {repository_name} # 切换到模型的代码仓目录
git checkout {branch} # 切换到对应分支
git reset --hard {commit_id} # 代码设置到对应的commit_id
cd {code_path} # 切换到模型代码所在路径,若仓库下只有该模型,则无需切换
特性列表 | 是否支持 |
---|---|
分布式训练 | 否 |
混合精度 | 是 |
数据并行 | 否 |
昇腾910 AI处理器提供自动混合精度功能,可以针对全网中float32数据类型的算子,按照内置的优化策略,自动将部分float32的算子降低精度到float16,从而在精度损失很小的情况下提升系统性能并减少内存使用。
npu_device.global_options().precision_mode='allow_mix_precision'
npu_device.open().as_default()
pip3 install requirements.txt
说明:依赖配置文件requirements.txt文件位于模型的根目录
cifar10/
├── mnist.npz
├── t10k-images.idx3-ubyte
├── t10k-labels.idx3-ubyte
├── train-images.idx3-ubyte
├── train-labels.idx3-ubyte
└── ...
单击“立即下载”,并选择合适的下载方式下载源码包。
开始训练
启动训练之前,首先要配置程序运行相关环境变量。
环境变量配置信息参见:
单卡训练
2.1 配置train_full_1p.sh脚本中data_path
(脚本路径GAN_ID2351_for_TensorFlow2.X/test/train_full_1p.sh),请用户根据实际路径配置,数据集参数如下所示:
--data_path=/home/MNIST
2.2 1p指令如下:
bash train_full_1p.sh --data_path=/home/MNIST
数据集准备。
模型训练。
参考“模型训练”中训练步骤。
模型评估。
参考“模型训练”中验证步骤。
convmixer_ID2501_for_TensorFlow2.X/
├── LICENSE
├── modelzoo_level.txt
├── README.md
├── requirements.txt
├── tf_v2_03_WGAN.py
├── test
│ ├── train_full_1p.sh
│ ├── train_performance_1p_static_eval.sh
│ ├── train_performance_1p_dynamic_eval.sh
--data_path 训练数据集路径
--train_epochs 训练epoch设置
--batch_size 训练bs设置
NA
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。