【笔记】GAN、CGAN:GAN是生成器和判别器进行博弈;CGAN是GAN基础上实现图片分类
GAN的算法流程如下(没有按照GAN论文中设定G和D更新次数差异):2、Pytorch实现GAN网络(1)生成器生成器网络采用全连接网络,输入是潜变量z,输出是3通道的图像数据。pytorch提供的全连接操作是nn.Linear(),需要的参数是输入特征长度和期望的输出特征长度;激活函数使用LeakyReLU(0.2)。forward()函数中“x = x.view(x.size(0), -1)”
GAN的算法流程如下(没有按照GAN论文中设定G和D更新次数差异):
2、Pytorch实现GAN网络
(1)生成器
生成器网络采用全连接网络,输入是潜变量z,输出是3通道的图像数据。pytorch提供的全连接操作是nn.Linear(),需要的参数是输入特征长度和期望的输出特征长度;激活函数使用LeakyReLU(0.2)。
forward()函数中“x = x.view(x.size(0), -1)”目的是将输入的(batch_size,channels, height, width)图像变换为(batch_size,-1),这样才能输入到全连接层之中。
class Generator(nn.Module):
def __init__(self, z_dim, target_image_size):
"""
initialize
:param z_dim: latent z dim, like z=100
:param target_image_size: tuple (3, h, w)
"""
super().__init__()
self.view_image_size = target_image_size[0]*target_image_size[1]*target_image_size[2]
self.out_image_size = target_image_size
self.z_dim = z_dim
self.generator = nn.Sequential()
self.generator.add_module(name="0", module=nn.Linear(in_features=self.z_dim, out_features=256))
self.generator.add_module(name="1", module=nn.LeakyReLU(0.2))
self.generator.add_module(name="2", module=nn.Linear(in_features=256, out_features=512))
self.generator.add_module(name="3", module=nn.LeakyReLU(0.2))
self.generator.add_module(name="4", module=nn.Linear(in_features=512, out_features=self.view_image_size))
self.generator.add_module(name="5", module=nn.Tanh())
def forward(self, x):
x = x.view(x.size(0), -1)
out = self.generator(x)
out = out.view(x.size(0), *self.out_image_size)
return out
(2)判别器
判别器的输入是图像数据,包括真实的图像数据和生成的图像数据。判别器使用的仍然是全连接层,其最后一层的输出特征数为1,因为这是一个二分类问题。
class Discriminator(nn.Module):
def __init__(self, image_size):
"""
initialize
:param image_size: tuple (3, h, w)
"""
super().__init__()
self.in_image_size = image_size[0]*image_size[1]*image_size[2]
self.discriminator = nn.Sequential()
self.discriminator.add_module(name="0", module=nn.Linear(in_features=self.in_image_size, out_features=256))
self.discriminator.add_module(name="1", module=nn.LeakyReLU(0.2))
self.discriminator.add_module(name="2", module=nn.Linear(in_features=256, out_features=64))
self.discriminator.add_module(name="3", module=nn.LeakyReLU(0.2))
self.discriminator.add_module(name="4", module=nn.Linear(in_features=64, out_features=1))
self.discriminator.add_module(name="5", module=nn.Sigmoid())
def forward(self, x):
x = x.view(x.size(0), -1)
out = self.discriminator(x)
return out
(3)损失函数和优化器
由于GAN包含两个网络。所以需要定义两个优化器分别进行优化。优化时固定一个网络的参数更新另一个。损失函数选择的BCELoss是一种较为常用的二分类损失,因为判别器进行真假的判别实质就是二分类问题:
adversarial_loss = torch.nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=options.lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=options.lr, betas=(0.5, 0.999))
(4)标签值和噪声值
判别器进行分类是一种有监督的分类,需要对应的标签值。下面定义的real_label是全为1的标签,fake_label是全为0的标签。real_label和fake_label是相对的概念,判别器使用real_label作为真实图片的标签,使用fake_label作为伪造样本的标签;生成器使用real_label作为自己伪造样本的标签。
real_label = torch.ones(size=(options.batch_size, 1), requires_grad=False).to(device)
fake_label = torch.zeros(size=(options.batch_size, 1), requires_grad=False).to(device)
噪声值需要使用均值为0,方差为1的标准正态分布如下:
z = torch.randn(size=(options.batch_size, options.z_dim)).to(device)
(5)网络训练
网络的训练关键是如何交替的更新生成器和判别器,以及如何正确的更新梯度。与普通的CNN分类或者回归的单个网络不同,GAN需要训练两个网络:训练生成器G时判别器D的参数固定不变,训练判别器D时生成器参数固定不变。
for _ in range(options.total_iterations):
for image, _ in data_loader:
# 1:数据准备
image = image.to(device)
z = torch.randn(size=(options.batch_size, options.z_dim)).to(device)
# #################################################
# 2:训练生成器
optimizer_G.zero_grad()
# 2.1:生成伪造样本
generated_image = generator(z)
# 2.2:计算判别器对伪造样本的输出的为真样本的概率值
d_out_fake = discriminator(generated_image)
# 2.3:计算生成器伪造样本不被认为是真的损失
g_loss = bce_loss(d_out_fake, real_label)
# 2.4:更新生成器
g_loss.backward()
optimizer_G.step()
# #################################################
# 3:训练判别器
optimizer_D.zero_grad()
# 3.1:计算判别器对真实样本给出为真的概率
d_out_real = discriminator(image)
# 3.2:计算判别器对真实样本的su's
real_loss = bce_loss(d_out_real, real_label)
# 3.3:计算判别器
d_out_fake = discriminator(generated_image.detach())
fake_loss = bce_loss(d_out_fake, fake_label)
d_loss = real_loss + fake_loss
# 3.4:更新判别器参数
d_loss.backward()
optimizer_D.step()
# 4:记录损失
record_iter.append(iteration)
record_g_loss.append(g_loss.item())
record_d_loss.append(d_loss.item())
iteration += 1
# #################################################
# 5:打印损失,保存图片
if iteration % 100 == 0:
with torch.no_grad():
generator.eval()
fixed_image = generator(fixed_z)
generator.train()
print("[iter: {}], [G loss: {}], [D loss: {}]".format(iteration, g_loss.item(), d_loss.item()))
save.save_image(image_tensor=fixed_image[0].squeeze(), out_name="results/"+str(iteration)+".jpg")
3、可视化结果分析
生成的动漫人脸如下所示,可以看出存在大量的噪声,虽然能看出脸的大致轮廓,原因可能是生成器网络拟合能力不足,也可能是判别器的原因,或者是损失函数选择不对:
损失函数的曲线如下,可以看出生成器损失和判别器损失始终都在震荡之中,互相博弈:
CGAN:
上文说到生成对抗网络GAN能够通过训练学习到数据分布,进而生成新的样本。可是GAN的缺点是生成的图像是随机的,不能控制生成图像属于何种类别。比如数据集包含飞机、汽车和房屋等类别,原始GAN并不能在测试阶段控制输出属于哪一类。
为此,研究人员提出了Conditional Generative Adversarial Network(简称CGAN),CGAN的图像生成过程是可控的。
本文包含以下3个方面:
(1)CGAN原理分析
(2)pytorch实现CGAN
(3)视觉结果和损失函数曲线
CGAN的思想是非常简单的,这也验证了那句话,越简单的想法越伟大!
1、CGAN原理分析
1.1 网络结构
CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息,实现条件生成模型。CGAN原文中作者说额外的条件信息可以是类别标签或者其它的辅助信息,本文使用条件信息(记为y)作为例子。
CGAN的核心操作是将条件信息加入到G和D中,下面分别进行讨论:
(1)原始GAN生成器输入是噪声信号,类别标签可以和噪声信号组合作为隐空间表示;
(2)原始GAN判别器输入是图像数据(真实图像和生成图像),同样需要将类别标签和图像数据进行拼接作为判别器输入。
从上图(来自CGAN论文)中可以看出,CGAN的网络相对于原始GAN网络并没有变化,改变的仅仅是生成器G和判别器D的输入数据,这就使得CGAN可以作为一种通用策略嵌入到其它的GAN网络中。
2.2 损失函数
原始GAN包含一个生成器和一个判别器,其中生成器G和判别器D进行极大极小博弈,损失函数如下:
CGAN添加的额外信息y只需要和x与z进行合并,作为G和D的输入即可,由此得到了CGAN的损失函数如下:
1.3 训练策略与实验结果
CGAN在mnist数据集上进行了实验,对于生成器:使用数字的类别y作为标签,并进行了one-hot编码,噪声z来自均均匀分布;噪声z映射到200维的隐层,类别标签映射到1000维的隐层,然后进行拼接作为下一层的输入,激活函数使用ReLU;最后一层使用Sigmoid函数,生成的样本为784维(使用的mnist长宽为28x28=784)。得到的实验结果如下:
上图中每行是由相同的标签生成的,说明CGAN的确可以通过给生成器特定的标签,实现特定模式(类别)的生成。CGAN还做了其它的实验,都证明了CGAN的模式控制能力。
2、pytorch实现
2.1 生成器实现
CGAN的生成器输入为噪声z和类别标签y的联合输入,所以这里我直接在对DCGAN的生成器进行改动(DCGAN的代码和分析参见我之前的文章):
class Generator(nn.Module):
def __init__(self, z_dim, num_classes):
super().__init__()
self.z_dim = z_dim
self.num_classes = num_classes
net = []
# 1:设定每次反卷积的输入和输出通道数
# 卷积核尺寸固定为3,反卷积输出为“SAME”模式
channels_in = [self.z_dim+self.num_classes, 512, 256, 128, 64]
channels_out = [512, 256, 128, 64, 3]
active = ["R", "R", "R", "R", "tanh"]
stride = [1, 2, 2, 2, 2]
padding = [0, 1, 1, 1, 1]
for i in range(len(channels_in)):
net.append(nn.ConvTranspose2d(in_channels=channels_in[i], out_channels=channels_out[i],
kernel_size=4, stride=stride[i], padding=padding[i], bias=False))
if active[i] == "R":
net.append(nn.BatchNorm2d(num_features=channels_out[i]))
net.append(nn.ReLU())
elif active[i] == "tanh":
net.append(nn.Tanh())
self.generator = nn.Sequential(*net)
def forward(self, x, label):
x = x.unsqueeze(2).unsqueeze(3)
label = label.unsqueeze(2).unsqueeze(3)
data = torch.cat(tensors=(x, label), dim=1)
out = self.generator(data)
return out
2.2 判别器的实现
CGAN的判别器需要使用图像(生成的和真实的)和类别标签y联合输入,所以这里也是对DCGAN的判别器第一层进行改动:
class Discriminator(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.num_classes = num_classes
net = []
# 1:预先定义
channels_in = [3+self.num_classes, 64, 128, 256, 512]
channels_out = [64, 128, 256, 512, 1]
padding = [1, 1, 1, 1, 0]
active = ["LR", "LR", "LR", "LR", "sigmoid"]
for i in range(len(channels_in)):
net.append(nn.Conv2d(in_channels=channels_in[i], out_channels=channels_out[i],
kernel_size=4, stride=2, padding=padding[i], bias=False))
if i == 0:
net.append(nn.LeakyReLU(0.2))
elif active[i] == "LR":
net.append(nn.BatchNorm2d(num_features=channels_out[i]))
net.append(nn.LeakyReLU(0.2))
elif active[i] == "sigmoid":
net.append(nn.Sigmoid())
self.discriminator = nn.Sequential(*net)
def forward(self, x, label):
label = label.unsqueeze(2).unsqueeze(3)
label = label.repeat(1, 1, x.size(2), x.size(3))
data = torch.cat(tensors=(x, label), dim=1)
out = self.discriminator(data)
out = out.view(data.size(0), -1)
return out
3、视觉结果和损失函数曲线
自己的数据包含3类:动漫脸、人脸、鞋。其实当时还选择了其它数据,但是最后发现,在数据集质量不够高时,生成的样本明显不够好,最后筛选才确定了使用这三个数据集。当然,自己的实验结果也非常差!迭代的总体次数为6000次左右,生成了下面的样本:
上面这个动漫脸完全看不清,人脸中也看不见嘴,下面这个结果更好些:
实际上,结果比较差的主要原因还是在于生成器的结构(不够深,拟合能力不够强),如果换成是近两年的生成器结构,生成的效果肯定会好很多。当然,调参数而是很重要的一个方面,自己也没有进行细致的调参。下面这张图显示了迭代过程中生成的图像的变化:
损失函数没有展示出收敛的趋势,尤其是生成器的损失似乎还在增加:
更多推荐
所有评论(0)