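# NOTE: hosting-page notice captured with this file (not part of the module):
# "Code pull complete; the page will refresh automatically."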
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
class Generator(nn.Module):
    """Fully connected generator that turns latent vectors into images.

    The network is a stack of four ``Linear -> BatchNorm1d -> LeakyReLU(0.2)``
    stages (widths 128, 256, 512, 1024), followed by a linear projection to
    ``prod(img_shape)`` values and a ``Tanh``.

    Args:
        latent_dim: Number of features in each latent vector.
        img_shape: Per-sample image shape; the flat network output is viewed
            as ``(batch, *img_shape)``.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # Consecutive entries give the (in, out) widths of each hidden stage.
        widths = [latent_dim, 128, 256, 512, 1024]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            # NOTE(review): 0.8 is passed as ``eps`` (BatchNorm1d's second
            # positional parameter), not as momentum. Confirm whether
            # momentum was intended before changing it.
            stack.append(nn.BatchNorm1d(fan_out, eps=0.8))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Project to one value per image element; Tanh bounds it to [-1, 1].
        stack.append(nn.Linear(1024, int(np.prod(img_shape))))
        stack.append(nn.Tanh())

        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: Latent batch whose last dimension is ``latent_dim``.

        Returns:
            A tuple ``(img, z)``: the generated batch viewed as
            ``(batch, *img_shape)`` and the unmodified input ``z``.
        """
        flat = self.model(z)
        batch = flat.size(0)
        return flat.view(batch, *self.img_shape), z
class Encoder(nn.Module):
    """Fully connected encoder that maps images to latent vectors.

    The network is ``Linear(prod(img_shape), 1024)`` followed by four
    ``Linear -> BatchNorm1d -> LeakyReLU(0.2)`` stages (widths 512, 256, 128,
    ``latent_dim``) and a final ``Tanh``.

    Args:
        latent_dim: Number of features in each produced latent vector.
        img_shape: Per-sample image shape; ``prod(img_shape)`` is the input
            width of the first linear layer.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear/BatchNorm/LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): 0.8 is passed as ``eps`` (BatchNorm1d's second
                # positional parameter), not as momentum. Value kept as-is.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=False))
            return layers

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.img_shape)), 1024),
            *block(1024, 512, normalize=True),
            *block(512, 256, normalize=True),
            *block(256, 128, normalize=True),
            *block(128, latent_dim, normalize=True),
            nn.Tanh(),
        )

    def forward(self, img):
        """Encode a batch of images.

        Args:
            img: Image batch, either shaped ``(batch, *img_shape)`` or already
                flattened to ``(batch, prod(img_shape))``.

        Returns:
            A tuple ``(img, z)``: the input reshaped to
            ``(batch, *img_shape)`` and the latent codes of shape
            ``(batch, latent_dim)``.
        """
        batch = img.size(0)
        # The first layer is a Linear over prod(img_shape) features, so the
        # image must be flattened *before* it is fed to the model. Flattening
        # only afterwards fails on any image-shaped input.
        z = self.model(img.reshape(batch, -1))
        img = img.reshape(batch, *self.img_shape)
        return (img, z)
class Discriminator(nn.Module):
    """Fully connected discriminator over joint (image, latent) pairs.

    The flattened image and the latent vector are concatenated along the
    feature dimension and passed through three hidden layers of width 512
    (the second and third followed by ``Dropout(0.4)``), ending in a single
    linear output with no activation.

    Args:
        latent_dim: Number of features in each latent vector.
        img_shape: Per-sample image shape; ``prod(img_shape)`` features are
            contributed by the flattened image.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        # Cast to a plain Python int (np.prod returns a NumPy scalar) so the
        # layer size is built the same way as in the sibling networks.
        joint_shape = int(latent_dim + np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(joint_shape, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, z):
        """Score a batch of (image, latent) pairs.

        Args:
            img: Image batch; everything after the batch dimension is
                flattened.
            z: Latent batch of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the unnormalised scores.
        """
        flat_img = img.view(img.size(0), -1)
        joint = torch.cat((flat_img, z), dim=1)
        validity = self.model(joint)
        return validity
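# NOTE: hosting-page moderation notice captured with this file (not part of the module):
# "This location may contain content unsuitable for display, so the page does not show it. You can review and modify it with the relevant editing functions."
# "If you confirm the content involves no inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, and nothing that violates national laws and regulations, you may click submit to appeal and we will handle it as soon as possible."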