Bootstrap Your Own Latent A New Approach to Self-Supervised Learning
Intro
文章提出一种不需要负样本来做自监督学习的方法,提出交替更新假说解释EMA方式更新target network防止collapse的原因,同时用梯度解释online网络和target不同构带来的好处。
Intuitions
传统基于对比学习的自监督方法都是需要构造负样本来防止mode collapse,但本文提出的方法却不需要负样本。试想一下,如果不使用负样本来防止mode collapse,那么一种比较直观的方式是使用一个teacher网络来指导需要训练的网络,假设teacher网络训练的足够好(比如全监督训练),是非collapse的网络,那么理论上训练一个目标网络和teacher网络对同一输入的输出一致,这样就可以了;但是这样做是违背自监督学习的初衷的,自监督学习不应依赖预训练网络,否则就失去了意义。于是一个简单的想法,假如去随机初始化一个teacher网络,并且认为这个teacher网络经过随机初始化是非collapse的网络,由于随机初始化的网络具有一定的图像先验,所以理论上也不会导致student网络学习到collapse的结果,但可能效果会差一点。
基于上述想法,文章做了这样一个实验,拿随机初始化的网络作为teacher网络并固定其权重,用目标网络学习该网络的输出,但这样的结果是虽然防止了collapse,但是特征表示能力不强;但是经过这样训练的student网络效果却比原始的teacher网络效果要好。
那么首先问题就来了,理论上student网络的上限是teacher网络呀,为什么student网络最终效果比teacher网络还要好?
这里的提升其实来自于增广,teacher和student网络的输入并不是完全一致的,而是两次不同的增广。
验证了上述想法,文章自然想去通过在训练过程中同时“训练”teacher网络来达到试student网络得到提升的效果,因为teacher网络不“训练”就能给student网络带来提升,而teacher网络“训练”到最好的情况是全监督的预训练,效果显然是最好的,在这之间如果能够每次提升student网络之后,利用student网络学习到的知识来提升一下teacher网络,那么这样是直观能够提升最终目标网络效果的。
Method
文章利用student网络提升teacher网络的方式是另teacher网络以以student网络EMA的方式更新来实现的,大体框架如下:
过程不难理解,但需要注意的是prediction部分和stop gradient部分,这里teacher网络和student网络backbone结构都是一致的,唯一不同是student网络多了一个prediction部分,那么问题来了,这一结构是否不可或缺,他起到了什么作用?第二点比较好理解,stop gradient部分即不用梯度方式更新teacher网络,而是使用EMA方式,那么第二个问题就是,为什么使用EMA的方式更新teacher网络能够方式collapse?
看看文章对这两个问题的解释:
如果teacher网络的更新过程是gradient descent的,那么显然他也会陷入collpase,但是teacher网络的更新方式是EMA,而非gradient descent,文章假定并没有一种gradient descent的loss能够同时去更新两个网络的参数使得loss最小,那么这样的loss更新方式其实更类似于GAN的方式,交替更新(这里只更新了一步,第一步并没有更新),按照这样的理解,交替过程的第一步,固定target使q达到最优,则有:
这一步其实并没有对参数进行更新,简单的理解是固定backbone,求解一个最优的q使得q与target的误差最小,但并不更新q,相当于得到了当前backbone参数下最优的投影,其实也比较合理,毕竟student网络和target网络有可能差异较大,通过这一层投影是可以拉近两者距离的(通过最小化q的方式)。
这时候更新整体参数( heta)(包括backbone和predictor,因为第一步并没有更新参数),则要求参数( heta)的梯度:
从这俩公式可以看出其实本文的假设是backbone的参数更新和predictor的参数更新可以等价为两个参数更新的EM过程,先求固定backbone参数下predictor q的最优参数,这时候q取得最佳值,然后利用最佳的q去更新backbone和q的参数。但是该过程并没有详细的数学证明。
因此,按照文章假设,该过程对参数的求导等价于对条件分布的方差的求导。需要注意的是,对于任意随机变量X、Y、Z,有(Var(X|Y) geq Var(X|Y,Z)),这里X是target projection,Y是online projection,Z是q作用后的随机变量,引入Z之后会让方差变小,也就是会让上面的梯度变小,所以解释了为什么本文要引入predictor q(也和上面的理解不谋而合,进一步拉近online分布和target分布的距离)。
对于collapse的分布z,有不等式(Var(z_{xi}^{'}|z_ heta) leq Var(z_{xi}^{'}|c)),其中c为常量(collapse分布),梯度下降只会使方差下降,因此能够找到某一参数( heta)使得其不为常量分布,进而避免了collapse。(感觉有一点牵强,因为不是严格小于)。
总结一下,上面两个问题的解释分别是,引入predictor q会进一步拉近online分布和target分布的距离进而使得梯度更小,更容易优化;而使用EMA方式更新target网络作者提出了交替过程假说来解释这样更新的效果,从梯度下降角度解释了常量分布的条件方差一定比非常量分布的条件方差大,从而说明本文的方法可以避免collapse。