1 Star 2 Fork 0

xurui-ryan/improved-wgan-pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
training_utils.py 3.32 KB
一键复制 编辑 原始数据 按行查看 历史
Hung Nguyen 提交于 2019-11-17 17:22 +08:00 . Merged master
from pathlib import Path
from collections import OrderedDict
import gpustat
import torch
from torch import autograd
import torch.nn.init as init
from torchvision import transforms, datasets
from torch.autograd import grad
from models.wgan import *
def mkdir_path(path):
    """Create directory ``path`` (and any missing parent directories).

    Args:
        path: A ``pathlib.Path`` or anything ``Path()`` accepts, e.g. a ``str``.
            It is wrapped in ``Path`` so plain string paths work too.

    An already-existing directory is not an error (``exist_ok=True``).
    """
    Path(path).mkdir(parents=True, exist_ok=True)
def showMemoryUsage(device=1):
    """Print the used/total memory of one GPU as reported by ``gpustat``.

    Args:
        device: Position in the list of GPUs returned by gpustat's query.
            Defaults to 1, i.e. the second GPU listed. An index past the end
            of that list raises ``IndexError``.
    """
    all_gpus = gpustat.GPUStatCollection.new_query().jsonify()["gpus"]
    stats = all_gpus[device]
    used, total = stats["memory.used"], stats["memory.total"]
    print(f"Used/total: {used}/{total}")
def weights_init(m):
    """Initialise the parameters of a single module in place.

    The one-module signature matches what ``torch.nn.Module.apply`` expects;
    presumably this is used as ``net.apply(weights_init)`` (confirm at caller).

    - ``MyConvo2d`` (project type from ``models.wgan``; accessed here as a
      wrapper exposing ``.conv`` and a ``he_init`` flag): ``conv.weight`` gets
      Kaiming-uniform init when ``m.he_init`` is truthy, Xavier-uniform
      otherwise; ``conv.bias`` is set to 0.
    - ``torch.nn.Linear``: Xavier-uniform weight, bias set to 0.
    - Any other module is left untouched.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, MyConvo2d):
        conv = m.conv
        if conv.weight is not None:
            if m.he_init:
                init.kaiming_uniform_(conv.weight)
            else:
                init.xavier_uniform_(conv.weight)
        if conv.bias is not None:
            init.constant_(conv.bias, 0.0)
    # Use the fully qualified name: this file never imports ``nn`` itself, so a
    # bare ``nn`` only resolves if ``models.wgan`` re-exports it via ``*``.
    if isinstance(m, torch.nn.Linear):
        if m.weight is not None:
            init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
def remove_module_str_in_state_dict(state_dict):
    """Strip the leading ``module.`` prefix from every key of a state dict.

    Wrappers such as ``torch.nn.DataParallel`` typically prefix parameter
    names with ``module.``. Only a *leading* prefix is removed. It is removed
    repeatedly, so nested wrappers (``module.module.x``) are also unwrapped.
    A ``module.`` occurring later in a key is kept, so a name such as
    ``block.submodule.weight`` is not mangled into ``block.subweight``.

    Args:
        state_dict: Mapping of parameter name -> value (e.g. tensors).

    Returns:
        A new ``OrderedDict`` with the same values in the same order and
        prefix-free keys.
    """
    prefix = "module."
    state_dict_rename = OrderedDict()
    for k, v in state_dict.items():
        name = k
        while name.startswith(prefix):
            name = name[len(prefix):]
        state_dict_rename[name] = v
    return state_dict_rename
def load_data(image_data_type, path_to_folder, data_transform, batch_size, classes=None, num_workers=5):
    """Build a shuffled ``DataLoader`` over an LSUN or ImageFolder dataset.

    Args:
        image_data_type: ``'lsun'`` or ``'image_folder'``.
        path_to_folder: Dataset root directory.
        data_transform: Transform applied to every sample.
        batch_size: Samples per batch; an incomplete final batch is dropped.
        classes: LSUN category selection forwarded to ``datasets.LSUN``. When
            ``None`` (the default) it is not forwarded, so LSUN's own default
            applies; passing ``None`` through is rejected by LSUN. Ignored
            for ``'image_folder'``.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``,
        ``drop_last=True`` and ``pin_memory=True``.

    Raises:
        ValueError: If ``image_data_type`` is not one of the values above.

    Side effect: calls ``torch.set_num_threads(1)`` for the whole process.
    """
    # Validate before touching global torch state.
    if image_data_type not in ("lsun", "image_folder"):
        raise ValueError("Invalid image data type")

    # torch issue
    # https://github.com/pytorch/pytorch/issues/22866
    torch.set_num_threads(1)

    if image_data_type == "lsun":
        lsun_kwargs = {} if classes is None else {"classes": classes}
        dataset = datasets.LSUN(path_to_folder, transform=data_transform, **lsun_kwargs)
    else:
        dataset = datasets.ImageFolder(root=path_to_folder, transform=data_transform)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
    )
def generate_image(netG, dim, batch_size, noise=None):
    """Run the generator without autograd and return images rescaled by ``x*0.5 + 0.5``.

    Args:
        netG: Generator callable; its output must have
            ``batch_size * 3 * dim * dim`` elements.
        dim: Image height/width.
        batch_size: Number of images; also the noise batch size when
            ``noise`` is not given.
        noise: Optional input to ``netG``. When ``None``, it is sampled with
            ``gen_rand_noise(batch_size)``. If ``netG`` exposes
            ``parameters()`` (e.g. an ``nn.Module``), the sampled noise is
            moved to the device of its first parameter.

    Returns:
        Tensor of shape ``(batch_size, 3, dim, dim)``. The ``*0.5 + 0.5``
        rescale presumably maps a tanh output in [-1, 1] to [0, 1]; this
        assumes a tanh generator and should be confirmed.
    """
    if noise is None:
        # gen_rand_noise requires batch_size; calling it without arguments
        # raised TypeError.
        noise = gen_rand_noise(batch_size)
        params = getattr(netG, "parameters", None)
        first_param = next(iter(params()), None) if callable(params) else None
        if first_param is not None:
            noise = noise.to(first_param.device)

    # Keep the forward pass itself inside no_grad so no autograd graph is
    # recorded for what is only image generation.
    with torch.no_grad():
        samples = netG(noise)
        samples = samples.view(batch_size, 3, dim, dim)
        samples = samples * 0.5 + 0.5
    return samples
def gen_rand_noise(batch_size, latent_dim=128):
    """Sample a batch of standard-normal latent vectors.

    Args:
        batch_size: Number of vectors to draw.
        latent_dim: Length of each vector. Defaults to 128, the size that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        Tensor of shape ``(batch_size, latent_dim)`` from ``torch.randn``,
        i.e. N(0, 1), on the default device with the default dtype.
    """
    return torch.randn(batch_size, latent_dim)
def calc_gradient_penalty(netD, real_data, fake_data, batch_size, dim, device, gp_lambda):
    """Compute the WGAN-GP gradient penalty.

    The critic is evaluated at random per-sample interpolations between real
    and fake images. The penalty is the squared deviation of the L2 norm of
    its input gradient from 1, averaged over the batch and scaled by
    ``gp_lambda``.

    Args:
        netD: Critic, called on a ``(batch_size, 3, dim, dim)`` tensor.
        real_data: Real images; must broadcast against
            ``(batch_size, 3, dim, dim)``.
        fake_data: Generated images with ``batch_size * 3 * dim * dim``
            elements; reshaped to ``(batch_size, 3, dim, dim)``.
        batch_size: Number of samples.
        dim: Image height/width.
        device: Device on which the interpolates are built.
        gp_lambda: Penalty weight.

    Returns:
        Scalar tensor ``gp_lambda * mean((||grad||_2 - 1) ** 2)``. The
        gradients are taken with ``create_graph=True``, so the penalty can
        itself be backpropagated.
    """
    # One mixing coefficient per sample, broadcast over C/H/W rather than
    # materialising an image-sized copy via expand().contiguous(). Drawing
    # batch_size values on the CPU keeps the same values for a given seed.
    alpha = torch.rand(batch_size, 1, 1, 1).to(device)
    fake_data = fake_data.view(batch_size, 3, dim, dim)
    interpolates = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
    interpolates = interpolates.to(device)
    interpolates.requires_grad_(True)

    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/xurui-ryan/improved-wgan-pytorch.git
git@gitee.com:xurui-ryan/improved-wgan-pytorch.git
xurui-ryan
improved-wgan-pytorch
improved-wgan-pytorch
master

搜索帮助