GAN在过去几年里已成为深度学习中最热门的子领域之一,Yann LeCun说GAN是过去10年机器学习最有趣的想法。
看完后,你应该对:
GAN是什么
具体要做一个简单的GAN应该怎么做
GAN能做啥
都很清楚了!
目录:
GAN简介(与图灵学习和纳什均衡的关系)
使用“垃圾邮件识别“进行详细说明(定义混淆矩阵,双方博弈流程)
GAN的应用(较为轻松)
GAN的Keras简单实现,使理解更清晰(到这里请认真 一些,严肃脸)
GAN最经常看到的例子就是斑马和马的互相转换了,相信你即使不知道GAN是什么,也曾见过这个例子。
<img src="https://pic3.zhimg.com/50/v2-fa3b28c8e2b44ecbb0ce8e2c7571e68a_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2136" data-rawheight="1958" data-default-watermark-src="https://pic3.zhimg.com/50/v2-d94fbe9ba69e9fe1ed5612021be8a089_hd.jpg" class="origin_image zh-lightbox-thumb" width="2136" data-original="https://pic3.zhimg.com/v2-fa3b28c8e2b44ecbb0ce8e2c7571e68a_r.jpg"/>
GAN简介
GAN的想法非常巧妙,它会创建两个不同的对立的网络,目的是让一个网络生成与训练集不同的且足以让另外一个网络难辨真假的样本。
“图灵学习”本质上可以对GAN进行概括。相关的“图灵测试”是广为人知的概念,即计算机试图与人对话并让人误以为它也是一个正常人类。“图灵测试”类似于GAN中generator(生成器)的目标,试图欺骗的是对应的‘adversary’--- discriminator(鉴别器)。
GAN可以用任何形式的generator和discriminator,不一定非得使用神经网络。而神经网络被广泛使用的主要原因是它一种通用函数逼近算法(universal function approximator),即我们能够使用大量节点的神经网络来模拟任何非线性的Input与Output之间的函数,相对其他方法具有更高的自由度,不会因为算法本身的能力而受限。对于generator或discriminator没有任何形式的限制,两者的形式也不必要相同。
这里我们先把generator和discriminator看做两个黑盒,里面包着全能的神经网络。generator(G)的输入是noise z,输出是生成的样本G(z)。然后将生成的样本混合真实数据输入discriminator(D),discriminator进行二分类并给出一个是否为真的打分D(G(z))。
<img src="https://pic1.zhimg.com/50/v2-a9c44141e295c72ba1d1c499e3fa72c7_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1212" data-default-watermark-src="https://pic3.zhimg.com/50/v2-c3e41b517e90391b0e4678acbe8ba9f4_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic1.zhimg.com/v2-a9c44141e295c72ba1d1c499e3fa72c7_r.jpg"/>
generator和discriminator的loss很大程度依赖于discriminator的好坏。G要maximizeD(G(z)), D要maximizeD(x),minimize D(G(z))。GAN的整个想法都以博弈论为基础,generator和discriminator相互对抗,最终相对于另一网络自己都处于峰值,达到纳什均衡。
<img src="https://pic4.zhimg.com/50/v2-017b32a744852858efc416d35a68fc95_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1212" data-default-watermark-src="https://pic2.zhimg.com/50/v2-0bf6b657d6f9d16fc3c5455a21c2f5c0_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic4.zhimg.com/v2-017b32a744852858efc416d35a68fc95_r.jpg"/>
可见下图描绘的是G和D在每一个epoch的Loss Function,最终平的部分即达到了纳什均衡。
<img src="https://pic3.zhimg.com/50/v2-7aebeb2386c8892020cb9bf0c215dadd_hd.jpg" data-caption="" data-size="normal" data-rawwidth="1844" data-rawheight="1080" data-default-watermark-src="https://pic4.zhimg.com/50/v2-063952aec2d856b47764f6f046cae4ae_hd.jpg" class="origin_image zh-lightbox-thumb" width="1844" data-original="https://pic3.zhimg.com/v2-7aebeb2386c8892020cb9bf0c215dadd_r.jpg"/>
垃圾邮件的例子
举个邮件分类的例子来进行说明,假设有一个叫Gary的营销人员试图骗过David的垃圾邮件分类器来发送垃圾邮件。Gary希望能尽可能地发送多的垃圾邮件,David希望尽可能少的垃圾邮件通过。理想情况下会达到纳什均衡,尽管我们谁都不想收到垃圾邮件。
<img src="https://pic3.zhimg.com/50/v2-055600162d8fcb50d8921dd8f2cd5d50_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1432" data-default-watermark-src="https://pic4.zhimg.com/50/v2-9271d4c5fb90270e1d880b2904523098_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic3.zhimg.com/v2-055600162d8fcb50d8921dd8f2cd5d50_r.jpg"/>
在收到邮件后,David可以查看spam filter的效果并通过”误报”或”漏报”来惩罚spam filter。
<img src="https://pic3.zhimg.com/50/v2-016a021afebc39faa793449213abd404_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1580" data-default-watermark-src="https://pic1.zhimg.com/50/v2-5ce5f37bc92b3cddd9f866679f2e990a_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic3.zhimg.com/v2-016a021afebc39faa793449213abd404_r.jpg"/>
假设Gary通过自己发送给自己可以验证他的垃圾邮件哪些通过了,那么Gary和David就可以通过混淆矩阵(confusion matrix,名字听起来高大上,其实就是个表格而已)来评价自己的工作做的如何:
<img src="https://pic3.zhimg.com/50/v2-2948038884a3944ce758d369b188f0ab_hd.jpg" data-caption="" data-size="normal" data-rawwidth="353" data-rawheight="181" data-default-watermark-src="https://pic3.zhimg.com/50/v2-e39721bd7c113455125bd5353427f67f_hd.jpg" class="content_image" width="353"/>
下面是Gary和David得到的混淆矩阵:
<img src="https://pic3.zhimg.com/50/v2-5eda0bdee4a33a84286a2da8439f0e7f_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1580" data-default-watermark-src="https://pic3.zhimg.com/50/v2-91242d9b3bdb9c2918a3c23fdd53bf52_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic3.zhimg.com/v2-5eda0bdee4a33a84286a2da8439f0e7f_r.jpg"/>
经过此之后,Gary和David都知道出了什么问题,并从错误中学习。Gary会基于之前的成功经验尝试其他的方法来生成更好的垃圾邮件。David会看一下spam filter哪里出错了并改进过滤机制。
<img src="https://pic4.zhimg.com/50/v2-20019365131cc592d0ad9a799c4f9272_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1340" data-default-watermark-src="https://pic3.zhimg.com/50/v2-307918bf25c3edf203e498c43fc2252d_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic4.zhimg.com/v2-20019365131cc592d0ad9a799c4f9272_r.jpg"/>
然后不断地重复这个过程,直到达到某种纳什均衡(当然,有可能最终导致模型崩溃,因为某一方找到了完美的伪装方法或者分辨垃圾邮件的方法)。
下面来详细看一下混淆矩阵的四个象限。
1.True Positive :邮件是Gary生成的垃圾邮件并且被David判定为垃圾邮件。 generator:被抓包,工作做的不够好,需要优化。 discriminator:当前不需要做什么。
<img src="https://pic4.zhimg.com/50/v2-5d14842b04848d3b1563e8ed28bfe971_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1154" data-default-watermark-src="https://pic2.zhimg.com/50/v2-e276d2f080432b74486eac6285bbebf6_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic4.zhimg.com/v2-5d14842b04848d3b1563e8ed28bfe971_r.jpg"/>
2.False Negative :邮件不是垃圾邮件,但是被David判定为垃圾邮件。 generator:当前不需要做什么。 discriminator:工作做的不够好,需要优化。 
<img src="https://pic2.zhimg.com/50/v2-b5996af44d36b0397b053f5587a8c4fc_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1154" data-default-watermark-src="https://pic1.zhimg.com/50/v2-c6461e7b34630c383895807a2403b7dc_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic2.zhimg.com/v2-b5996af44d36b0397b053f5587a8c4fc_r.jpg"/>
3.False Positive :邮件是垃圾邮件,但是被David判定为正常邮件。 generator:当前不需要做什么 discriminator:工作做的不够好,需要优化。
<img src="https://pic4.zhimg.com/50/v2-60447bf80a75b14b58828be046d583e5_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1154" data-default-watermark-src="https://pic1.zhimg.com/50/v2-1c8a0681c5cfdb670c1b499deebc3b05_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic4.zhimg.com/v2-60447bf80a75b14b58828be046d583e5_r.jpg"/>
4.True Negative :邮件不是垃圾邮件,David也判定是正常邮件。 generator:当前不需要做什么 discriminator:当前不需要做什么
<img src="https://pic3.zhimg.com/50/v2-01e28085ddabfd250f479a8804a64945_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1084" data-default-watermark-src="https://pic4.zhimg.com/50/v2-1076b16761447a5047ccfeab66d07f3b_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic3.zhimg.com/v2-01e28085ddabfd250f479a8804a64945_r.jpg"/>
基于上面讨论,图示Network如何训练的:
<img src="https://pic4.zhimg.com/50/v2-017b32a744852858efc416d35a68fc95_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1212" data-default-watermark-src="https://pic2.zhimg.com/50/v2-0bf6b657d6f9d16fc3c5455a21c2f5c0_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic4.zhimg.com/v2-017b32a744852858efc416d35a68fc95_r.jpg"/>
训练的步骤包括: 1.取batch的训练集x,和随机生成noise z; 2.计算loss; 3.使用back propagation更新generator和discriminator;
我们已经分析好了,在True Positive,False Negative,False Positive情况下需要更新:
True Positive :意味着generator生成的fake数据被抓包,需要对generator进行优化。需要经过参数被固定的discriminator计算loss,更新generator的权重。注意一次只能对两个网络中的一个进行参数调整。
<img src="https://pic1.zhimg.com/50/v2-496bfe52e0b04225cd81098b9511e7dc_hd.jpg" data-caption="" data-size="small" data-rawwidth="1652" data-rawheight="1928" data-default-watermark-src="https://pic2.zhimg.com/50/v2-53bd54fe9eeb8531f5050c6acbeedf55_hd.jpg" class="origin_image zh-lightbox-thumb" width="1652" data-original="https://pic1.zhimg.com/v2-496bfe52e0b04225cd81098b9511e7dc_r.jpg"/>
False Negative :意味着真的训练集被discriminator错认为fake数据。只更新discriminator的权重。
<img src="https://pic1.zhimg.com/50/v2-9e4cf2f5371425b7f4d682b02a4a4872_hd.jpg" data-caption="" data-size="small" data-rawwidth="1604" data-rawheight="1768" data-default-watermark-src="https://pic2.zhimg.com/50/v2-f82e27b37e304d7def47ccc67e41ca8a_hd.jpg" class="origin_image zh-lightbox-thumb" width="1604" data-original="https://pic1.zhimg.com/v2-9e4cf2f5371425b7f4d682b02a4a4872_r.jpg"/>
False Positive :generator生成的fake数据,被discriminator判定为真的训练集。只对discriminator进行更新。
<img src="https://pic2.zhimg.com/50/v2-08bab95fc9e4e7e66a9a90a837d51f24_hd.jpg" data-caption="" data-size="small" data-rawwidth="1652" data-rawheight="1928" data-default-watermark-src="https://pic4.zhimg.com/50/v2-05afdc525c9837cc25e1d5af53893b68_hd.jpg" class="origin_image zh-lightbox-thumb" width="1652" data-original="https://pic2.zhimg.com/v2-08bab95fc9e4e7e66a9a90a837d51f24_r.jpg"/>
现在让我们用更数学 的角度来解释一下:
我们有一个已知的real的分布,generator生成了一个fake的分布。因为这个两个分布不完全相同,所以他们之间存在KL-divergence,也就是损失函数不为0。
<img src="https://pic2.zhimg.com/50/v2-be776a313a288b2247cc94821872e6f1_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1394" data-default-watermark-src="https://pic4.zhimg.com/50/v2-dda0ba78e7459632d370a6e04997cf6e_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic2.zhimg.com/v2-be776a313a288b2247cc94821872e6f1_r.jpg"/>
discriminator同时看到real的分布和fake的分布。如果discriminator能分清楚来自generator生成的与来自real分布的,就会生成loss并反向传播更新generator的权重。
<img src="https://pic3.zhimg.com/50/v2-eec59b83f1a117b9e91118233044ce23_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1920" data-default-watermark-src="https://pic3.zhimg.com/50/v2-5782d8c5d1109ffc53b4cacd2db6a319_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic3.zhimg.com/v2-eec59b83f1a117b9e91118233044ce23_r.jpg"/>
generator更新完成后,生成的fake数据更符合real的分布。
<img src="https://pic3.zhimg.com/50/v2-dd1789a1be1176413c6f0e3da18a7e27_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2528" data-rawheight="1968" data-default-watermark-src="https://pic1.zhimg.com/50/v2-baf1698dbf08ea43912d8bd51892d823_hd.jpg" class="origin_image zh-lightbox-thumb" width="2528" data-original="https://pic3.zhimg.com/v2-dd1789a1be1176413c6f0e3da18a7e27_r.jpg"/>
但是如果生成的data仍然不够接近real的分布,discriminator依然能识别出来了,因此再次对generator进行权重更新。
<img src="https://pic1.zhimg.com/50/v2-4a390d25de9e063f2aeef4b221192cd9_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1950" data-default-watermark-src="https://pic2.zhimg.com/50/v2-681932f0715922dc934a492132a842b3_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic1.zhimg.com/v2-4a390d25de9e063f2aeef4b221192cd9_r.jpg"/>
终于这次discriminator被骗过了,它认为generator生成的fake数据就是符合real分布的。这个就对应False Positive的情况,需要对discriminator进行更新。
<img src="https://pic2.zhimg.com/50/v2-9a7bd71afcfb711d5b8761d46f06f7f1_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2548" data-rawheight="1948" data-default-watermark-src="https://pic4.zhimg.com/50/v2-8a4f13ae51ecfb95f02d6d98a46cb1d8_hd.jpg" class="origin_image zh-lightbox-thumb" width="2548" data-original="https://pic2.zhimg.com/v2-9a7bd71afcfb711d5b8761d46f06f7f1_r.jpg"/>
Loss反向传播来更新discriminator的权重。
<img src="https://pic4.zhimg.com/50/v2-b3cd2c4eadb1184a255dd65c61addf08_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1740" data-default-watermark-src="https://pic2.zhimg.com/50/v2-e9cdd3ea15e5e52500b1b670cb24852e_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic4.zhimg.com/v2-b3cd2c4eadb1184a255dd65c61addf08_r.jpg"/>
继续这个过程,直到generator生成的分布与real分布无法区分时,网络达到纳什均衡。
<img src="https://pic3.zhimg.com/50/v2-1e8b3bba8ed7a8382fe24ca3de870136_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1904" data-default-watermark-src="https://pic2.zhimg.com/50/v2-bcbf1edead93421fc72e301e81a26e73_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic3.zhimg.com/v2-1e8b3bba8ed7a8382fe24ca3de870136_r.jpg"/>
GAN的实际应用
(Conditional) Synthesis—条件生成
最好玩的比如Text2Image、Image2Text。可以基于一段文字生成一张图片,比如这个Multi-Condition GAN(MA-GAN)的text-to-image的例子:
<img src="https://pic2.zhimg.com/50/v2-6cf1a9ebd63b815858f60ed0b04b63e0_hd.jpg" data-size="normal" data-rawwidth="2800" data-rawheight="1446" data-default-watermark-src="https://pic2.zhimg.com/50/v2-fc6269e1f03cea2d62ba3d0102e8c5c7_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic2.zhimg.com/v2-6cf1a9ebd63b815858f60ed0b04b63e0_r.jpg"/> 详见 https://arxiv.org/pdf/1902.06068.pdf
Data Augmentation—数据增强 GAN学习训练集样本的分布,然后进行采样生成新的样本,我们可以使用这些样本来增强训练集。一般我们都是通过对原训练集的图片进行旋转和扭曲来进行增强,这里GAN提供了一种新的方法。
Style Transfer和Manipulation-风格转换 将一张图片的style转移到另外一张图像上,这与neural style transfer非常类似。Neural Style Transfer可以认为是把Style Image的风格加入到Content Image里。因为只有一张Style Image,所以它其实学到的很难完全是Style的特征,因为一个画家的风格很难通过一幅作品就展现出来。GAN能够很好的从多个作品中学习到画家的真正风格特征。
第2/3列为neural style transfer的效果,第5列为cycleGAN:
<img src="https://pic3.zhimg.com/50/v2-16ca621f08c817b2a1500c2f97143fee_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2930" data-rawheight="1824" data-default-watermark-src="https://pic4.zhimg.com/50/v2-8a00eb64a5b32c5754bc90e7dfcf4403_hd.jpg" class="origin_image zh-lightbox-thumb" width="2930" data-original="https://pic3.zhimg.com/v2-16ca621f08c817b2a1500c2f97143fee_r.jpg"/>
可以看出对背景特别有效,比如对云的转换等:
<img src="https://pic1.zhimg.com/50/v2-a9a13947d0eff422fa65ba428f623294_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1662" data-default-watermark-src="https://pic4.zhimg.com/50/v2-0230b8d210e28d03c28f63f56fcdc480_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic1.zhimg.com/v2-a9a13947d0eff422fa65ba428f623294_r.jpg"/>
GAN在动物和水果上的效果:
<img src="https://pic3.zhimg.com/50/v2-7f9faf67e56619237b9876d562dd1972_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2444" data-rawheight="1708" data-default-watermark-src="https://pic4.zhimg.com/50/v2-6d8c67f517c7eacfe89c95fbb9a63c40_hd.jpg" class="origin_image zh-lightbox-thumb" width="2444" data-original="https://pic3.zhimg.com/v2-7f9faf67e56619237b9876d562dd1972_r.jpg"/>
四季变换:
<img src="https://pic2.zhimg.com/50/v2-8bf73ccf96b46d4bbcad41f7172dc08c_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1096" data-default-watermark-src="https://pic1.zhimg.com/50/v2-4800be7ad0f4d17a578aad43cf523962_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic2.zhimg.com/v2-8bf73ccf96b46d4bbcad41f7172dc08c_r.jpg"/>
改变照片的景深:
<img src="https://pic4.zhimg.com/50/v2-d642f43dad788093aa67d4c93a0027bd_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1306" data-default-watermark-src="https://pic1.zhimg.com/50/v2-cca0a7ba51dfc2a50ae5fc2232aa0df5_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic4.zhimg.com/v2-d642f43dad788093aa67d4c93a0027bd_r.jpg"/>
对线稿填充不变成真实的物体:
<img src="https://pic4.zhimg.com/50/v2-efc05538ac3c11613281be3ec1f5c1ac_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1250" data-default-watermark-src="https://pic2.zhimg.com/50/v2-7a7ebe5d6a720d1989eb040f37ad282d_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic4.zhimg.com/v2-efc05538ac3c11613281be3ec1f5c1ac_r.jpg"/>
可以利用风格转换来渲染图像,变成游戏GTA风格的:
<img src="https://pic2.zhimg.com/50/v2-adb63c6d91f89b5781f5c6a29c556f60_hd.jpg" data-caption="" data-size="normal" data-rawwidth="1532" data-rawheight="696" data-default-watermark-src="https://pic1.zhimg.com/50/v2-a57b7942ba28630f95b5435450e316c3_hd.jpg" class="origin_image zh-lightbox-thumb" width="1532" data-original="https://pic2.zhimg.com/v2-adb63c6d91f89b5781f5c6a29c556f60_r.jpg"/>
将白天变夜晚:
<img src="https://pic3.zhimg.com/50/v2-b56207be4bbe351f69388fb4d058d8e8_hd.jpg" data-caption="" data-size="normal" data-rawwidth="1536" data-rawheight="690" data-default-watermark-src="https://pic1.zhimg.com/50/v2-be1a1f96a42d1011507008c6f9b3549d_hd.jpg" class="origin_image zh-lightbox-thumb" width="1536" data-original="https://pic3.zhimg.com/v2-b56207be4bbe351f69388fb4d058d8e8_r.jpg"/>
style transfer可以具体见这个survey。 https://arxiv.org/pdf/1902.06068.pdf
Image Super-Resolution 即将图像从低分辨率LR恢复到高分辨率HR。
<img src="https://pic4.zhimg.com/50/v2-6563a67cb6fd50ea7b0203c83f9a3130_hd.jpg" data-caption="" data-size="normal" data-rawwidth="2800" data-rawheight="1160" data-default-watermark-src="https://pic1.zhimg.com/50/v2-c927ae620853530e046252df7b8aa7fa_hd.jpg" class="origin_image zh-lightbox-thumb" width="2800" data-original="https://pic4.zhimg.com/v2-6563a67cb6fd50ea7b0203c83f9a3130_r.jpg"/>
GAN的简单实现
下面是一个最简化的使用Keras实现的GAN,基于CelebA数据集。
model的定义:
定义discriminator,然后compile;
定义generator,不进行compile;
定义一个model包含generator和discriminator,把discriminator设为not trainable,然后compile:
<img src="https://pic3.zhimg.com/50/v2-d59f4f7fbee6d8182af94d79dad068b4_hd.jpg" data-caption="" data-size="small" data-rawwidth="778" data-rawheight="390" data-default-watermark-src="https://pic1.zhimg.com/50/v2-aca9b8ec9fabe43e14edc76dfb04d0b4_hd.jpg" class="origin_image zh-lightbox-thumb" width="778" data-original="https://pic3.zhimg.com/v2-d59f4f7fbee6d8182af94d79dad068b4_r.jpg"/>
训练loop:
从训练集选择R张图像;
采样大小为N的随机噪声,输入generator产生F张fake的图像;
将R张训练集与F张fake图像和对应的label输入discriminator进行训练;
采样大小为N的随机噪声;
用train_on_batch,以目标label为1对generator进行训练更新。
Imports:
import keras
from keras.layers import *
from keras.datasets import cifar10
import glob, cv2, os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import clear_output
Global Parameters:
优先在前面声明所有变量,比较清晰
SPATIAL_DIM = 64 # Spatial dimensions of the images.
LATENT_DIM = 100 # Dimensionality of the noise vector.
BATCH_SIZE = 32 # Batchsize to use for training.
DISC_UPDATES = 1 # Number of discriminator updates per training iteration.
GEN_UPDATES = 1 # Nmber of generator updates per training iteration.
FILTER_SIZE = 5 # Filter size to be applied throughout all convolutional layers.
NUM_LOAD = 10000 # Number of images to load from CelebA. Fit also according to the available memory on your machine.
NET_CAPACITY = 16 # General factor to globally change the number of convolutional filters.
PROGRESS_INTERVAL = 80 # Number of iterations after which current samples will be plotted.
ROOT_DIR = 'visualization' # Directory where generated samples should be saved to.
if not os.path.isdir(ROOT_DIR):
os.mkdir(ROOT_DIR)
数据预处理
对所有训练的image做normalize处理
def plot_image(x):
plt.imshow(x * 0.5 + 0.5)
X = []
# Reference to CelebA dataset here. I recommend downloading from the Harvard 2019 ComputeFest GitHub page (there is also some good coding tutorials here)
faces = glob.glob('../Harvard/ComputeFest 2019/celeba/img_align_celeba/*.jpg')
for i, f in enumerate(faces):
img = cv2.imread(f)
img = cv2.resize(img, (SPATIAL_DIM, SPATIAL_DIM))
img = np.flip(img, axis=2)
img = img.astype(np.float32) / 127.5 - 1.0
X.append(img)
if i >= NUM_LOAD - 1:
break
X = np.array(X)
plot_image(X[4])
X.shape, X.min(), X.max()
<img src="https://pic2.zhimg.com/50/v2-314ee8a798459ba3697b9477495208fe_hd.jpg" data-caption="" data-size="normal" data-rawwidth="508" data-rawheight="504" data-default-watermark-src="https://pic1.zhimg.com/50/v2-051c7cd1c709cb88c6eb8184910af350_hd.jpg" class="origin_image zh-lightbox-thumb" width="508" data-original="https://pic2.zhimg.com/v2-314ee8a798459ba3697b9477495208fe_r.jpg"/>
定义架构
将block抽象单独定义出来能让代码更简洁。padding方法选择"same":
def add_encoder_block(x, filters, filter_size):
x = Conv2D(filters, filter_size, padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(filters, filter_size, padding='same', strides=2)(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.3)(x)
return x
使用encoder_block定义discriminator,循环使用encoder block,并不断增大filter size:
def build_discriminator(start_filters, spatial_dim, filter_size):
inp = Input(shape=(spatial_dim, spatial_dim, 3))
# Encoding blocks downsample the image.
x = add_encoder_block(inp, start_filters, filter_size)
x = add_encoder_block(x, start_filters * 2, filter_size)
x = add_encoder_block(x, start_filters * 4, filter_size)
x = add_encoder_block(x, start_filters * 8, filter_size)
x = GlobalAveragePooling2D()(x)
x = Dense(1, activation='sigmoid')(x)
return keras.Model(inputs=inp, outputs=x)
下面定义decoder block,主要做反卷积:
def add_decoder_block(x, filters, filter_size):
x = Deconvolution2D(filters, filter_size, strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.3)(x)
return x
下面定义generator,循环使用decoder block并不断降低filter size:
def build_generator(start_filters, filter_size, latent_dim):
inp = Input(shape=(latent_dim,))
# Projection.
x = Dense(4 * 4 * (start_filters * 8), input_dim=latent_dim)(inp)
x = BatchNormalization()(x)
x = Reshape(target_shape=(4, 4, start_filters * 8))(x)
# Decoding blocks upsample the image.
x = add_decoder_block(x, start_filters * 4, filter_size)
x = add_decoder_block(x, start_filters * 2, filter_size)
x = add_decoder_block(x, start_filters, filter_size)
x = add_decoder_block(x, start_filters, filter_size)
x = Conv2D(3, kernel_size=5, padding='same', activation='tanh')(x)
return keras.Model(inputs=inp, outputs=x)
训练
构建网络和训练流程 构建GAN时,定义discriminator为not trainable是非常重要的。因为我们不能同时训练两个网络,就像我们不能校正多个同时变化的东西。所以在训练某一个网络时,需要保持其他部分是固定不变的。
def construct_models(verbose=False):
# 1. Build discriminator.
discriminator = build_discriminator(NET_CAPACITY, SPATIAL_DIM, FILTER_SIZE)
discriminator.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(lr=0.0002), metrics=['mae'])
# 2. Build generator.
generator = build_generator(NET_CAPACITY, FILTER_SIZE, LATENT_DIM)
# 3. Build full GAN setup by stacking generator and discriminator.
gan = keras.Sequential()
gan.add(generator)
gan.add(discriminator)
::discriminator.trainable = False:: # Fix the discriminator part in the full setup.
gan.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(lr=0.0002), metrics=['mae'])
if verbose: # Print model summaries for debugging purposes.
generator.summary()
discriminator.summary()
gan.summary()
return generator, discriminator, gan
模型训练:
def run_training(start_it=0,num_epochs=1000):
# Save configuration file with global parameters
config_name = 'gan_cap' + str(NET_CAPACITY) + '_batch' + str(BATCH_SIZE) + '_filt' + str(FILTER_SIZE) + '_disc' + str(DISC_UPDATES) + '_gen' + str(GEN_UPDATES)
folder = os.path.join(ROOT_DIR, config_name)
if not os.path.isdir(folder):
os.mkdir(folder)
# Initiate loop variables
avg_loss_discriminator = []
avg_loss_generator = []
total_it = start_it
# Start of training loop
for epoch in range(num_epochs):
loss_discriminator = []
loss_generator = []
for it in range(200):
# Update discriminator.
for i in range(DISC_UPDATES):
# Fetch real examples (you could sample unique entries, too).
imgs_real = X[np.random.randint(0, X.shape[0], size=BATCH_SIZE)]
# Generate fake examples.
noise = np.random.randn(BATCH_SIZE, LATENT_DIM)
imgs_fake = generator.predict(noise)
d_loss_real = ::discriminator.train_on_batch::(imgs_real, np.ones([BATCH_SIZE]))[1]
d_loss_fake = ::discriminator.train_on_batch::(imgs_fake, np.zeros([BATCH_SIZE]))[1]
# Progress visualizations.
if total_it % PROGRESS_INTERVAL == 0:
plt.figure(figsize=(5,2))
# We sample separate images.
num_vis = min(BATCH_SIZE, 8)
imgs_real = X[np.random.randint(0, X.shape[0], size=num_vis)]
noise