GAN学习整理
最近看了一些关于生成对抗网络的学习,记录以便翻查
GAN
1.原始GAN
Ian J. Goodfellow等人于2014年10月在Generative Adversarial Network中提出了一个通过对抗过程估计生成模型的新框架,下文会简称GAN。文中提到将会有两个模块,第一个是生成模块G,另一个是判别模块D,可以简单地看作是两个网络相互博弈的过程。如果用图像为例子,判别模块的大致结构会如下图所示。
从图中我们可以看到判别模块会有来自两个地方的输入,第一个来自一个真实图片,另一个是来自生成模块的生成的假图。对于判别模块来说,真实图片的标签都是1,而生成的假图的标签都为0,因此显然这是一个不需要数据集有标识的学习方法,所以GAN是一个无监督的学习方法。在原始的GAN中判别去会利用sigmoid函数对输入的数据运算出一个概率,去判断它是否是真是的照片。利用逻辑交叉熵作为损失函数,Adam作为优化方法。经过不断地迭代运算,判别模块就可以学习到区分真假照片的能力。另一方面,在图的右边是一个生成器,下图将展示生成器的具体结构。
如图所示,生成器首先接受一个自定义的噪声,经过运算,这个运算可以根据实际情况而定,如果是图片,可以是一个卷积的神经网络,将噪声转化成一张生成的假图,然后输入到判别模块中进行判断。这时候我们要另生成模块生成图向真实靠拢,所以我们将标签设置为1,之后利用逻辑交叉熵作为损失函数和Adam作为优化器进行学习。这样生成器就会学习到生成一张近似真实图片的能力。
然后生成模块和判断模块进行交替的迭代训练,一方面判断模块在不断地完善自己的判断能力,希望能够区分真生成的假图,另一方面生成模块不断增加自己的生成能力,希望能够蒙骗判断模块。因此两者产生对抗,在不断的迭代对抗中相互进步提高,最终我们的目标是得到一个比较好的生成模块,能够产生一张近似真实的图片。在论文中提到,可以设置两者训练的周期,也就是说可以训练两次判断模块再训练一次生成模块,也可以训练三次之后再训练一次。这个值可以由训练的效果和速度来定。
目前为止,我已经介绍完了生成对抗网路的大体框架,整个模型看上去很好,但是在实际应用中,往往会出现训练不稳定,梯度消失等问题。下面这两篇论文就是介绍原始的GAN为什么会出现这样的问题,以及该如何解决。
2.GAN的问题
在经历了几年的发展后,GAN得到了广泛的应用,包括在数据数量偏少的情况下增加数据,通过机器学习去进行创作等方面有很多应用。同时学者也对GAN进行了许多的改动,包括利用卷积网络的DCGA和限制噪声的输入的InfoGAN它能够有效地生产丰富的彩色人像照片。但是,这些算法似乎都很难去彻底解决GAN训练不稳定和损失函数无法指示训练进程等问题。直到在2017年,Wasserstein GAN的出现。以下简称WGAN.在介绍WGAN之前,我们先介绍一篇Wasserstein GAN的前作,它对于WAN的提出起到了很大的作用,它指出了许多GAN中会出现训练问题的原因,这篇文章就是Martin Arjovsky等人在2017年发表的Towards Principled Methods for Training Generative Adversarial Networks。这篇文章从数学的角度解释了为什么GAN会出现训练问题。下面我简要地分析总结一下。
这个说明如果Pr(x)=0且Pg(x)!=0,最优判别模块就应该非常自信地给出概率0;如果Pr(x)=Pg(x),说明该样本是真是假的可能性刚好为一半,此时最优判别模块也应该给出概率0.5。然后经过一系列变化,最后作者得到了一个变换后的生成函数损失函数:
从公式中可以看出,判别模块的损失函数其实和真实图片分布和生成的假图分布的JS散度相关。降低损失函数相当于降低两个分布的JS散度,判别模块越好,两者越接近。但是问题就出在这个JS散度上,事实上真实分布和生成分布往往差距很大,他们之间并没有重叠的部分,JS散度很可能一直为log2,导致损失函数的梯度一直为0,因此出现了梯度消失的情况,难以训练。因此,判别模块越好,生成器梯度消失越严重。
3.WGAN
针对上篇论文所提出的问题,Martin Arjovsky又在2017年发布了第二篇论文Wasserstein GAN,这篇论文提出了一个Wasserstein距离的概念去解决上述存在的问题。Wasserstein距离又叫Earth-Mover(EM)距离,定义如下:
解释如下:Π(Pr(x),Pg(x))是Pr(x)和Pg(x)组合起来的所有可能的联合分布的集合,反过来说,Π(Pr(x),Pg(x))中每一个分布的边缘分布都是Pr(x)和Pg(x)。对于每一个可能的联合分布γ而言,可以从(x,y)~γ中采样得到一个真实样本和一个生成样本,并算出这对样本的距离x-y,所以可以计算该联合分布下样本对距离的期望值E(x,y)~γ(x-y)。在所有可能的联合分布中能够对这个期望值取到的下界,就定义为Wasserstein距离。
Wasserstein距离相比KL散度、JS散度的优越性在于,即便两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。作者利用Wasserstein重新对GAN网络进行了修改,定义出了一下的损失函数: WGAN生成器损失函数
其中上面的是就是判别模块形成的函数,为生成器函数。总结起来,WAGN主要对GAN做出了以下改进:
- 判别模块最后一层去掉sigmoid
- 生成器和判别模块的loss不取log
- 每次更新判别模块的参数之后把它们的绝对值截断到不超过一个固定常数
- 不要用基于动量的优化算法(包括Adam),推荐RMSProp或者SGD。