生成式对抗网络(GAN)
一、什么是生成式对抗网络GAN?
在知乎上看到一个比较有趣的例子:
女生让男生给自己拍照,可是一直不满意男生拍的照片,就对照“别人家的男朋友”拍的照片,一次次让男生去改,直到女生满意。
在这个例子中,
-
男生可以被看作是GAN中的生成模型(Generative Model);
-
女生可以被看作是GAN中的判别模型(Discriminator);
-
整个拍照的过程可以被看作是博弈式的训练过程
-
男生(生成模型)的目的:拍出女朋友满意的照片(生成一幅和真实图片极其相似的图片)
-
女生(判别模型)的目的:分辨男朋友拍的照片,不满意的打回去(判别生成图片与真实图片是否相似,如果不够相似,打回去)
上述博弈过程,如果采用神经网络作为模型类型,则被称为生成式对抗网络(GAN)
正如视频中提到的两个问题:
-
为什么罪犯制造的假币越来越逼真?
为什么GAN可以生成数据?
二、GAN的详细介绍
GAN的框架
判别器D(Discriminator):区分真实样本和虚假样本。D是一个神经网络,经过运算后,如果是真实的图片,给出real(1);如果是假的图片,给出fake(0)
随机噪声z:从一个先验分布(人为定义,一般是均匀分布或者正态分布)中随机采样的向量
真实样本x:从数据库中采样的样本
合成样本G(z):生成模型G输出的样本
生成器G(Generator):欺骗判别器。生成虚假数据,使得判别器D能够尽可能给出高的评分。生成器不断改变自己,直到生成的很多图片能够欺骗判别器
GAN目标函数
训练算法:
1.随机初始化生成器和判别器
2.交替训练判别器D和生成器G,直到收敛
- 步骤一:固定生成器G(不优化),训练判别器D区分真实图像与合成图像(赋予真实图像高分,赋予合成图像低分)(用监督训练二分类问题)
- 步骤二:固定判别器D,训练生成器G欺骗判别器D(更新生成器的参数,使其合成的图片被生成器D赋予高分)(最大化问题)
训练一个生成模型
一个能够生成我们想要的数据的模型(图模型、函数、神经网络)
GAN通过一个低维向量 生成器(全连接神经网络)
cGAN生成可控的数据 生成器(全连接神经网络)
DCGAN 生成器(卷积神经网络)
WGAN 生成器(WGAN)重新设计目标函数,训练更稳定,生成数据质量更棒
KL散度和JS散度
-
KL散度(Kullback-Leibler divergence)
一种衡量两个概率分布的匹配程度的指标,又称为KL距离,相对熵
当P(x)和Q(x)的相似度越高,KL散度越小
KL散度主要有两个性质:
(1)不对称性
(2)非负性
KL散度本质是用来衡量两个概率分布的差异一种数学计算方式;由于用到比值除法不具备对称性。
神经网络训练时为何不用KL散度,从数学上来讲,KL散度多减了一个H(P);P代表真实分布,Q代表估计的分布
极大似然估计等价于最小化生成数据分布和真实分布的KL散度
-
JS散度(Jensen-Shannon divergence)
JS散度也称为JS距离,是KL散度的一种变形
JS散度主要性质:
(1)值域范围(JS散度的值域范围是[0,1],相同是0,相反为1)
(2)对称性
(3)交叉熵
很多情况下,假设数据符合高斯分布是不合理的,数据分布是无法用公式显示的写出来的
因此用高斯模型去拟合数据分布,我们需要一个更通用的生成模型,可以拟合任意数据分布,如下
GAN:生成式对抗网络通过对抗训练,间接计算出散度JS,使得模型可以优化
GAN做的事情:
1.最大化判别器损失,等价于计算合成数据分布和真实数据分布的JS散度
2.最小化生成器损失,等价于最小化JS散度(也就是优化生成模型 )
三、DCGAN
四、代码练习
(一)GAN
- 通过make_moons生成双半月形的数据,同时把数据点画出来
-
定义生成器、判别器、优化器
判别器中使用了sigmoid函数(可能是因为需要判别生成的图片是否是真实图片,即相当于是一个二分类的问题,因此用sigmoid函数)
优化器选择的是adam
-
对抗训练
整个对抗训练可以分为两部分:
- 第一部分(固定生成器G,改进判别器D)
- 第二部分(固定判别器D,改进生成器G)
-
修改learning_rate和batch_size
学习率为0.0001,batch_size为50的结果:
学习率为0.001,batch_size为250的结果:
可以明显看出随着batch_size的增大、loss的减小,效果明显改善。
(个人猜测:增大batch_size的值后,能够一次性处理更多的数据,从而能够更好地把握大方向,训练的波动程度更小)
(二)CGAN(条件生成-对抗网络)
- 对比于GAN,CGAN在生成器以及判别器上都多了一个标签作为输入
- 生成器的输入是噪声和标签,输出是生成图
- 判别器的输入是生成图,真实图以及标签,输出是真和假
步骤与GAN相似,不同的是在生成器和判别器的定义中加入了10维的标签信息
全连接判别器:
全连接生成器:
epoch改为100后:
在epoch为100时,辨别器的损失为0.00030,效果不太好
(三)DCGAN(深度卷积对抗网络)
- 对比于GAN,在判别器和生成器中使用了卷积结构(在第二个、第三个、第四个滑动卷积层中使用BN加快网络收敛),同样添加Sigmoid激活函数
滑动卷积判别器:
反滑动卷积生成器:
- 第一层:把输入线性变换成256×4×4的矩阵,并在这个基础上做反卷积
- 第四层:不使用BN,使用tanh激活函数
epoch为30时,结果如下:
在epoch改为100后,效果不如epoch为30的结果(不想明白什么原因)