Self-Knowledge Distillation: A Simple Way for Better Generalization论文阅读
这是一篇2020年6月挂载arxiv上的论文,目前还处于under review。这篇论文从正则化方法可以提高网络的泛化能力这一点出发,提出了一个简单且有效的正则化方法——
Self-KD
,即自知识蒸馏。从作者所做的一系列实验可以看出来,Self-KD不仅在提高精度上面起到了作用,还能提供高质量的置信度估计(Expected Calibration Error,ECE)。
1. Introdution
本论文从正则化方法能够提升模型泛化能力这个点切入,将正则化方法归为以下几类:
- 限制函数空间的正则化方法:L1、L2 weight decay
- 在训练中加入随机值的正则化方法:Dropout
- 通过规范化层与层之间内部激活值的正则化方法:BatchNorm
- 数据增强的正则化方法:Cutout、Mixup、AugMix, etc.
- 调整targets的正则化方法:label smoothing
接着,本论文继续探讨了label smoothing这方法存在的不足之处,即它无法和现在的先进正则化方法形成互补。所以本篇文章的motivation就出来了:如何找到一个更有效地策略去软化hard targets,从而得到包含更多信息的labels?
2. Self-Knowledge Distillation
作者先从最原始的Knowledge Distillation(知识蒸馏)出发,公式如下:
p
~
i
=
e
x
p
(
z
i
(
x
)
/
τ
)
∑
j
e
x
p
(
z
j
(
x
)
/
τ
)
(1)
widetilde{p}_i = frac{exp(z_i(x)/ au)}{sum_j exp(z_j(x)/ au)} ag{1}
p
i=∑jexp(zj(x)/τ)exp(zi(x)/τ)(1)
L K D ( x , y ) = ( 1 − α ) H ( y , P S ( x ) ) + α τ 2 H ( P ~ T ( x ; τ ) , P ~ S ( x ; τ ) ) (2) L_{KD}(x, y) = (1-alpha) H(y, P^S(x)) + alpha au^2 H(widetilde{P}^T(x; au), widetilde{P}^S(x; au)) ag{2} LKD(x,y)=(1−α)H(y,PS(x))+ατ2H(P T(x;τ),P S(x;τ))(2)
其中 τ au τ为温度变量, z ( x ) z(x) z(x)为logit vector, P ~ T ( x ; τ ) widetilde{P}^T(x; au) P T(x;τ)为老师网络的蒸馏输出, P ~ S ( x ; τ ) widetilde{P}^S(x; au) P S(x;τ)为学生网络的蒸馏输出, H ( . , . ) H(. , .) H(.,.)是交叉熵损失。
当 τ = 1 au=1 τ=1的时候知识蒸馏的损失可以化简如下:
L K D ( x , y ) = H ( ( 1 − α ) y + α P T ( x ) , P S ( x ) ) (3) L_{KD}(x, y) = H( (1-alpha)y+alpha P^T(x), P^S(x)) ag{3} LKD(x,y)=H((1−α)y+αPT(x),PS(x))(3)
在上述公式(3)的基础之上,论文将Knowledge Distillation变为Self Knowledge Distillation,即distill knowledge of itself
。具体做法是论文将上一个epoch的模型作为下一个epoch的老师,公式如下:
L
K
D
,
t
(
x
,
y
)
=
H
(
(
1
−
α
t
)
y
+
α
t
P
i
<
t
T
(
x
)
,
P
t
S
(
x
)
)
(4)
L_{KD, t}(x, y) = H( (1-alpha_t)y+alpha_t P_{i<t}^T(x), P_t^S(x)) ag{4}
LKD,t(x,y)=H((1−αt)y+αtPi<tT(x),PtS(x))(4)
L K D , t ( x , y ) = H ( ( 1 − α t ) y + α t P t − 1 T ( x ) , P t S ( x ) ) (5) L_{KD, t}(x, y) = H((1-alpha_t)y+alpha_t P_{t-1}^T(x), P_t^S(x)) ag{5} LKD,t(x,y)=H((1−αt)y+αtPt−1T(x),PtS(x))(5)
从上述公式可以知道,整个self-KD loss还有一个超参数
α
t
alpha_t
αt需要确定。
α
t
alpha_t
αt是一个决定老师模型的soft targets在蒸馏训练过程中重要程度的超参数。由分析可知,模型在刚刚训练的时候往往是不够稳定的,所以这个时候权重应该比较低,随着训练进程的不断推进,这个参数应该逐步增加。所以本论文采用了最简单却有效的线性增长去动态调节
α
t
alpha_t
αt:
α
t
=
t
T
α
T
alpha_t = frac{t}{T} alpha_T
αt=TtαT
其中, t是当前epochs的次数,T是总的epochs次数,
α
T
alpha_T
αT是最后一个epoch的
α
alpha
α(本论文取的0.7)。
3. Better Accuracy & High Quality of Confidence Estimates
Better Accuracy很好理解,就是作者在classification, object detection, and machine translation这三个领域的一些benchmarks上做了实验,结果表明加入了Self-KD之后比baseline有了明显的精度提升。
Confidence Estimates则是用了Expected Calibration Error
(ECE)这一个指标进行度量。结果表明加入Self-KD之后模型会有更小的ECE,即更高质量的置信度估计。Expected Calibration Error原理分析请移步我的另一篇博客:Expected Calibration Error (ECE) 模型校准原理分析。