代码拉取完成,页面将自动刷新
from pathlib import Path
from collections import OrderedDict
import gpustat
import torch
from torch import autograd
import torch.nn.init as init
from torchvision import transforms, datasets
from torch.autograd import grad
from models.wgan import *
def mkdir_path(path):
    """Create the directory ``path``, including any missing parents.

    Does nothing if the directory already exists.

    Args:
        path: Directory to create. May be a ``pathlib.Path``, a ``str`` or
            any other ``os.PathLike`` object.
    """
    # Wrapping in Path lets callers pass plain strings as well as Path
    # objects; Path(Path(...)) is a cheap no-op for the existing callers.
    Path(path).mkdir(parents=True, exist_ok=True)
def showMemoryUsage(device=1):
    """Print the used and total memory of one GPU as reported by gpustat.

    Args:
        device: Index into the ``"gpus"`` list of the gpustat query result.
    """
    snapshot = gpustat.GPUStatCollection.new_query().jsonify()
    gpu = snapshot["gpus"][device]
    used, total = gpu["memory.used"], gpu["memory.total"]
    print("Used/total: {}/{}".format(used, total))
def weights_init(m):
    """Initialise the parameters of a single module in place.

    ``MyConvo2d`` modules get Kaiming-uniform weights when their ``he_init``
    flag is truthy and Xavier-uniform weights otherwise. ``nn.Linear``
    modules always get Xavier-uniform weights. Biases, where present, are
    set to zero. Any other module type is left untouched.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, MyConvo2d):
        conv = m.conv
        if conv.weight is not None:
            fill_weight = init.kaiming_uniform_ if m.he_init else init.xavier_uniform_
            fill_weight(conv.weight)
        if conv.bias is not None:
            init.constant_(conv.bias, 0.0)

    # Deliberately a second independent check rather than an ``elif``.
    if isinstance(m, nn.Linear):
        if m.weight is not None:
            init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
def remove_module_str_in_state_dict(state_dict):
    """Return a copy of ``state_dict`` with ``module`` scopes removed from keys.

    Every dotted path component that is exactly ``module`` and is followed
    by a further component is dropped, e.g. ``"module.conv.weight"`` becomes
    ``"conv.weight"`` and ``"enc.module.fc.bias"`` becomes ``"enc.fc.bias"``.
    Components that merely end in ``module`` (``"submodule.weight"``) are
    kept intact; a plain substring replacement of ``"module."`` would
    corrupt them.

    Args:
        state_dict: Mapping from dotted parameter names to values.

    Returns:
        An ``OrderedDict`` with the renamed keys, in the original iteration
        order. If two keys collapse to the same name the later one wins.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        # The last component is the parameter name itself; only the scopes
        # in front of it can be a wrapper's ``module`` attribute.
        *scopes, leaf = key.split(".")
        kept = [scope for scope in scopes if scope != "module"]
        renamed[".".join(kept + [leaf])] = value
    return renamed
def load_data(image_data_type, path_to_folder, data_transform, batch_size, classes=None, num_workers=5):
    """Build a shuffling ``DataLoader`` over an LSUN or image-folder dataset.

    Args:
        image_data_type: ``'lsun'`` or ``"image_folder"``.
        path_to_folder: Dataset root handed to the torchvision dataset class.
        data_transform: Transform applied by the dataset.
        batch_size: Number of samples per batch.
        classes: Forwarded to ``datasets.LSUN``; ignored for image folders.
            NOTE(review): ``None`` is passed through unchanged - confirm that
            the installed torchvision accepts it.
        num_workers: Number of loader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles, drops the last
        incomplete batch and uses pinned memory.

    Raises:
        ValueError: If ``image_data_type`` is not one of the supported values.
    """
    # Limit torch to one thread; see
    # https://github.com/pytorch/pytorch/issues/22866
    torch.set_num_threads(1)

    if image_data_type not in ("lsun", "image_folder"):
        raise ValueError("Invalid image data type")

    if image_data_type == "lsun":
        dataset = datasets.LSUN(path_to_folder, classes=classes, transform=data_transform)
    else:
        dataset = datasets.ImageFolder(root=path_to_folder, transform=data_transform)

    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)
def generate_image(netG, dim, batch_size, noise=None):
    """Run the generator and return its samples rescaled by ``x * 0.5 + 0.5``.

    The forward pass runs under ``torch.no_grad()``, so the returned tensor
    carries no autograd history.

    Args:
        netG: Callable generator; its output must hold
            ``batch_size * 3 * dim * dim`` elements.
        dim: Height and width of the generated images.
        batch_size: Number of images to generate.
        noise: Optional generator input. When ``None`` a fresh batch is drawn
            with ``gen_rand_noise(batch_size)`` on torch's default device;
            pass ``noise`` explicitly if the generator lives elsewhere.

    Returns:
        Tensor of shape ``(batch_size, 3, dim, dim)``.
    """
    if noise is None:
        # gen_rand_noise requires the batch size; calling it without
        # arguments raised TypeError.
        noise = gen_rand_noise(batch_size)

    with torch.no_grad():
        samples = netG(noise)
        samples = samples.view(batch_size, 3, dim, dim)
        # Affine rescale; maps a [-1, 1] generator output onto [0, 1].
        samples = samples * 0.5 + 0.5
    return samples


def gen_rand_noise(batch_size, noise_dim=128, device=None):
    """Draw a batch of standard-normal noise vectors.

    Args:
        batch_size: Number of noise vectors.
        noise_dim: Length of each noise vector (128 by default).
        device: Optional device to create the tensor on; ``None`` uses
            torch's default device.

    Returns:
        Tensor of shape ``(batch_size, noise_dim)``.
    """
    return torch.randn(batch_size, noise_dim, device=device)
def calc_gradient_penalty(netD, real_data, fake_data, batch_size, dim, device, gp_lambda):
    """Compute the WGAN-GP gradient penalty for a discriminator.

    Real and fake samples are mixed with one uniform random coefficient per
    sample. The discriminator is evaluated on the mixtures, and the squared
    deviation of the per-sample gradient L2 norm from 1 is averaged and
    scaled by ``gp_lambda``.

    Args:
        netD: Discriminator, called on a ``(batch_size, C, dim, dim)`` tensor.
        real_data: Real images, shaped ``(batch_size, C, dim, dim)``.
        fake_data: Generated images with ``batch_size * C * dim * dim``
            elements; reshaped to match ``real_data``. The channel count
            ``C`` is inferred from the element count.
        batch_size: Number of samples in each batch.
        dim: Image height and width.
        device: Device the penalty is computed on.
        gp_lambda: Multiplier applied to the mean penalty.

    Returns:
        Scalar tensor holding the penalty. The gradient graph is kept
        (``create_graph=True``) so the penalty can be backpropagated.
    """
    # Move both batches to the target device before mixing them; combining
    # them with a device-resident alpha first fails when they live elsewhere.
    real = real_data.detach().to(device)
    fake = fake_data.detach().to(device).reshape(batch_size, -1, dim, dim)

    # One coefficient per sample, broadcast across channels and pixels.
    # Drawn with the default (CPU) generator and then moved, so seeded runs
    # consume the random stream the same way as before.
    alpha = torch.rand(batch_size, 1).view(batch_size, 1, 1, 1).to(device)

    interpolates = alpha * real + (1 - alpha) * fake
    interpolates.requires_grad_(True)

    disc_interpolates = netD(interpolates)

    # ones_like matches the discriminator output's shape, dtype and device.
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
    return gradient_penalty
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。