
lekmao / BidirectionalGAN

This repository has not declared an open-source license file (LICENSE). Check the project description and its upstream code dependencies before use.
modules.py 2.70 KB
jaeho committed on 2020-11-24 16:47 · solved mode collapse
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=True),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        # z: [batch_size, latent_dim] -> img: [batch_size, *img_shape]
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return (img, z)
class Encoder(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Encoder, self).__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=False))
            return layers

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.img_shape)), 1024),
            *block(1024, 512, normalize=True),
            *block(512, 256, normalize=True),
            *block(256, 128, normalize=True),
            *block(128, latent_dim, normalize=True),
            nn.Tanh()
        )

    def forward(self, img):
        # img: [batch_size, *img_shape] -> z: [batch_size, latent_dim]
        # Flatten the image before the fully connected model expects
        # a vector of size prod(img_shape).
        z = self.model(img.view(img.size(0), -1))
        img = img.view(img.size(0), *self.img_shape)
        return (img, z)
class Discriminator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Discriminator, self).__init__()
        joint_shape = latent_dim + int(np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(joint_shape, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, z):
        # img: [batch_size, *img_shape], z: [batch_size, latent_dim]
        # Score the joint pair (img, z); the output is an unnormalized logit.
        joint = torch.cat((img.view(img.size(0), -1), z), dim=1)
        validity = self.model(joint)
        return validity
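

# --- Minimal usage sketch (not part of the original file) ---
# This block is an illustrative smoke test of how the three modules fit
# together in a BiGAN: the discriminator scores joint pairs (G(z), z) and
# (x, E(x)). The latent size (100), image shape (1, 28, 28), and batch size
# are assumed MNIST-like values, not taken from the repository's training
# script.
if __name__ == "__main__":
    latent_dim = 100
    img_shape = (1, 28, 28)
    batch_size = 8

    G = Generator(latent_dim, img_shape)
    E = Encoder(latent_dim, img_shape)
    D = Discriminator(latent_dim, img_shape)

    # "Fake" joint pair: a generated image together with the code that produced it.
    z = torch.randn(batch_size, latent_dim)
    fake_img, z = G(z)
    d_fake = D(fake_img, z)

    # "Real" joint pair: a (stand-in) real image together with its inferred code.
    real_img = torch.rand(batch_size, *img_shape) * 2 - 1  # placeholder data in [-1, 1]
    real_img, z_hat = E(real_img)
    d_real = D(real_img, z_hat)

    # Both outputs are [batch_size, 1] logits, suitable for a loss such as
    # BCEWithLogitsLoss in the adversarial objective.
    print(d_fake.shape, d_real.shape)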