# GNGAN-PyTorch

**Repository Path**: diyage_wxx/GNGAN-PyTorch

## Basic Information

- **Project Name**: GNGAN-PyTorch
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: Not specified
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2022-02-25
- **Last Updated**: 2022-03-21

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

# Gradient Normalization for Generative Adversarial Networks

Yi-Lun Wu, Hong-Han Shuai, Zhi-Rui Tam, Hong-Yu Chiu

Paper: [https://arxiv.org/abs/2109.02235](https://arxiv.org/abs/2109.02235)

This is the official implementation of Gradient Normalized GAN (GN-GAN).

## Requirements

- Python 3.8.9
- Python packages
    ```sh
    # update `pip` for installing tensorboard.
    pip install -U pip setuptools
    pip install -r requirements.txt
    ```

## Datasets

- CIFAR-10

    The PyTorch built-in CIFAR-10 dataset is downloaded automatically.

- STL-10

    The PyTorch built-in STL-10 dataset is downloaded automatically.

- CelebA-HQ 128/256

    We obtain CelebA-HQ from [this repository](https://github.com/suvojit-0x55aa/celebA-HQ-dataset-download) and preprocess it into an `lmdb` file.

    - 256x256

        ```
        python dataset.py path/to/celebahq/256 ./data/celebahq/256
        ```

    - 128x128

        We split the data into train and test sets by filename; the test set contains images `27001.jpg` to `30000.jpg`.
        ```
        python dataset.py path/to/celebahq/128/train ./data/celebahq/128
        ```

    The folder structure:
    ```
    ./data/celebahq
    ├── 128
    │   ├── data.mdb
    │   └── lock.mdb
    └── 256
        ├── data.mdb
        └── lock.mdb
    ```

- LSUN Church Outdoor 256x256 (training set)

    The folder structure:
    ```
    ./data/lsun/church/
    ├── data.mdb
    └── lock.mdb
    ```

## Preprocessing Datasets for FID

Pre-calculated statistics for FID can be downloaded [here](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC?usp=sharing):
- cifar10.train.npz - Training set of CIFAR10
- cifar10.test.npz - Testing set of CIFAR10
- stl10.unlabeled.48.npz - Unlabeled set of STL10 in resolution 48x48
- celebahq.3k.128.npz - Last 3k images of CelebA-HQ 128x128
- celebahq.all.256.npz - Full dataset of CelebA-HQ 256x256
- church.train.256.npz - Training set of LSUN Church Outdoor

Folder structure:
```
./stats
├── celebahq.3k.128.npz
├── celebahq.all.256.npz
├── church.train.256.npz
├── cifar10.test.npz
├── cifar10.train.npz
└── stl10.unlabeled.48.npz
```

**NOTE**: All the values reported in our paper (Inception Score and FID) are calculated with the official implementations of these metrics, not with our own implementation.

## Training

- Configuration files
    - We use `absl-py` to parse, save and reload the command line arguments.
    - All the configuration files can be found in `./config`.
    - The compatible configurations are listed in the following table:

    |Script|Configurations|Multi-GPU|
    |-----------------|--------------|:-------:|
    |`train.py`|`GN-GAN_CIFAR10_CNN.txt`<br>`GN-GAN_CIFAR10_RES.txt`<br>`GN-GAN_CIFAR10_BIGGAN.txt`<br>`GN-GAN_STL10_CNN.txt`<br>`GN-GAN_STL10_RES.txt`<br>`GN-GAN-CR_CIFAR10_CNN.txt`<br>`GN-GAN-CR_CIFAR10_RES.txt`<br>`GN-GAN-CR_CIFAR10_BIGGAN.txt`<br>`GN-GAN-CR_STL10_CNN.txt`<br>`GN-GAN-CR_STL10_RES.txt`| |
    |`train_ddp.py`|`GN-GAN_CELEBAHQ128_RES.txt`<br>`GN-GAN_CELEBAHQ256_RES.txt`<br>`GN-GAN_CHURCH256_RES.txt`|:heavy_check_mark:|

- Run the training script with a compatible configuration:
    - `train.py` supports training GANs on `CIFAR10` and `STL10`, e.g.,
        ```sh
        python train.py \
            --flagfile ./config/GN-GAN_CIFAR10_RES.txt
        ```
    - `train_ddp.py` is optimized for multi-GPU training, e.g.,
        ```
        CUDA_VISIBLE_DEVICES=0,1,2,3 python train_ddp.py \
            --flagfile ./config/GN-GAN_CELEBAHQ256_RES.txt
        ```

- Generate images from checkpoints with the following flags:
    - `--eval`: evaluate the best checkpoint.
    - `--save PATH`: save the generated images to `PATH`.

    ```
    python train.py \
        --flagfile ./logs/GN-GAN_CIFAR10_RES/flagfile.txt \
        --eval \
        --save path/to/generated/images
    ```

## How to integrate Gradient Normalization into your work?

The function `normalize_gradient` is implemented on top of the `torch.autograd` module, so you can apply gradient normalization to your discriminator by changing a single line of its forward propagation.
```python
import torch
from torch.nn import BCEWithLogitsLoss
from models.gradnorm import normalize_gradient

net_D = ...  # discriminator
net_G = ...  # generator
loss_fn = BCEWithLogitsLoss()
z_dim = 128  # latent size of net_G (assumed value; match your generator)

# Update discriminator
x_real = ...                                     # real data
x_fake = net_G(torch.randn(64, z_dim)).detach()  # fake data; detached so D's loss does not reach G
pred_real = normalize_gradient(net_D, x_real)    # replaces net_D(x_real)
pred_fake = normalize_gradient(net_D, x_fake)    # replaces net_D(x_fake)
loss_real = loss_fn(pred_real, torch.ones_like(pred_real))
loss_fake = loss_fn(pred_fake, torch.zeros_like(pred_fake))
(loss_real + loss_fake).backward()               # backward propagation
...

# Update generator
x_fake = net_G(torch.randn(64, z_dim))           # fake data
pred_fake = normalize_gradient(net_D, x_fake)    # replaces net_D(x_fake)
loss_fake = loss_fn(pred_fake, torch.ones_like(pred_fake))
loss_fake.backward()                             # backward propagation
...
```

## Citation

If you find our work relevant to your research, please cite:
```
@InProceedings{GNGAN_2021_ICCV,
    author = {Yi-Lun Wu and Hong-Han Shuai and Zhi-Rui Tam and Hong-Yu Chiu},
    title = {Gradient Normalization for Generative Adversarial Networks},
    booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
    month = {Oct},
    year = {2021}
}
```
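The integration section imports `normalize_gradient` from `models/gradnorm.py` without showing its body. For readers who want to see what such a wrapper computes, here is a minimal sketch of the gradient-normalized discriminator from the paper, $\hat{f}(x) = f(x) / (\lVert \nabla_x f(x) \rVert + |f(x)|)$, evaluated per sample. It assumes `net_D` returns one logit per sample and that samples do not interact inside `net_D` (e.g. no batch normalization). The name `normalize_gradient_sketch` is hypothetical, and the code is not necessarily identical to the repository's implementation.

```python
import torch
import torch.nn as nn


def normalize_gradient_sketch(net_D: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of gradient normalization (not the repo's code).

    Computes f_hat(x) = f(x) / (||grad_x f(x)||_2 + |f(x)|) for each sample.
    Assumes net_D outputs one value per sample, shape (B,) or (B, 1), and
    that samples do not interact inside net_D (e.g. no batch norm).
    """
    # Gradients w.r.t. the input are needed, so make sure x tracks them.
    x = x.requires_grad_(True)
    f = net_D(x)  # raw discriminator output
    # Summing f via grad_outputs=ones gives each sample's own input gradient
    # (valid because samples are independent). create_graph=True keeps the
    # gradient differentiable so losses on f_hat can backprop through it.
    grad = torch.autograd.grad(
        outputs=f, inputs=x, grad_outputs=torch.ones_like(f), create_graph=True
    )[0]
    # Per-sample L2 norm of the input gradient, reshaped to broadcast with f.
    grad_norm = grad.flatten(start_dim=1).norm(p=2, dim=1)
    grad_norm = grad_norm.view(-1, *([1] * (f.dim() - 1)))
    return f / (grad_norm + f.abs())
```

In the integration example, `normalize_gradient_sketch(net_D, x_real)` could stand in for `normalize_gradient(net_D, x_real)`. The key design choice is `create_graph=True`: it keeps $\lVert \nabla_x f(x) \rVert$ in the autograd graph, so the BCE loss on the normalized output also backpropagates through the gradient norm into the discriminator's parameters. Because the denominator is at least $|f(x)|$, the normalized output always satisfies $|\hat{f}(x)| < 1$ whenever the gradient is nonzero.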