在本篇章,我们将专门针对vanilla RNN,也就是所谓的原始RNN这种网络结构进行前向传播介绍和反向梯度推导。更多相关内容请见《神经网络的梯度推导与代码验证》系列介绍。
注意:
- 本系列的关注点主要在反向梯度推导以及代码上的验证,涉及到的前向传播相对而言不会做太详细的介绍。
- 反向梯度求导涉及到矩阵微分和求导的相关知识,请见《神经网络的梯度推导与代码验证》之数学基础篇:矩阵微分与求导,内含手把手教学级的内容。
目录
提醒:
- 后续会反复出现$oldsymbol{delta}^{l}$这个(类)符号,它的定义为$oldsymbol{delta}^{l} = frac{partial l}{partialoldsymbol{z}^{oldsymbol{l}}}$,即loss $l$对$oldsymbol{z}^{oldsymbol{l}}$的导数
- 其中$oldsymbol{z}^{oldsymbol{l}}$表示第$l$层(DNN,CNN,RNN或其他例如max pooling层等)未经过激活函数的输出。
- $oldsymbol{a}^{oldsymbol{l}}$则表示$oldsymbol{z}^{oldsymbol{l}}$经过激活函数后的输出。
这些符号会贯穿整个系列,还请留意。
4.1 vanilla RNN的前向传播
先贴一张vanilla(朴素)RNN的前传示意图。
上图中左边是RNN模型没有按时间展开的图,如果按时间序列展开,则是上图中的右边部分。我们重点观察右边部分的图。这幅图描述了在序列索引号t附近RNN的模型。其中:
- $oldsymbol{x}^{(t)}$代表在序列索引号$t$时训练样本的输入。注意这里的$t$只是代表序列索引,不一定非得具备时间上的含义,例如$oldsymbol{x}^{(t)}$可以是某句子的第$t$个字(的词向量)。
- $oldsymbol{h}^{(t)}$代表在序列索引号$t$时模型的隐藏状态。$oldsymbol{h}^{(t)}$由$oldsymbol{x}^{(t)}$和$oldsymbol{h}^{(t-1)}$共同决定
- $oldsymbol{a}^{(t)}$代表在序列索引号$t$时模型的输出。$oldsymbol{o}^{(t)}$只由模型当前的隐藏状态$oldsymbol{h}^{(t-1)}$决定
- $oldsymbol{L}^{(t)}$代表在序列索引号$t$时模型的损失函数。
- $oldsymbol{y}^{(t)}$代表在序列索引号$t$时训练样本序列的真实输出
- $oldsymbol{U},oldsymbol{W},oldsymbol{V}$三个矩阵式我们模型的线性相关系数,它们在整个vanilla RNN网络中共享的,这点和DNN很不同。也正因为是共享的,它体现了RNN模型的“循环/递归”的核心思想。
4.1.1 RNN前向传播计算公式
有了上面的模型,RNN的前向传播算法就很容易得到了。
对于任意一个序列索引号$t$,我们隐藏状态$oldsymbol{h}^{(t)}$由$oldsymbol{x}^{(t)}$和$oldsymbol{h}^{(t-1)}$共同得到:
$oldsymbol{h}^{(t)} = sigmaleft( oldsymbol{z}^{(t)} ight) = sigmaleft( {oldsymbol{U}oldsymbol{x}^{(t)} + oldsymbol{W}oldsymbol{h}^{(t - 1)} + oldsymbol{b}} ight)$
其中$sigma$为RNN的激活函数,一般为$tanh$。
序列索引号为$t$时,模型的输出$oldsymbol{o}^{(t)}$的表达式也比较简单:
$oldsymbol{o}^{(t)} = oldsymbol{V}oldsymbol{h}^{(t - 1)} + oldsymbol{c}$
在最终在序列索引号时我们的预测输出为:
${hat{oldsymbol{y}}}^{(t)} = sigmaleft( oldsymbol{o}^{(t)} ight)$
对比下列公式:
$oldsymbol{h}^{(t)} = sigmaleft( {oldsymbol{U}oldsymbol{x}^{(t)} + oldsymbol{W}oldsymbol{h}^{(t - 1)} + oldsymbol{b}} ight)$
$oldsymbol{a}^{l} = sigmaleft( {oldsymbol{W}^{l}oldsymbol{a}^{l - 1} + oldsymbol{b}^{l}} ight)$
上面的是vanilla RNN的$oldsymbol{h}^{(t)}$的递推公式,而下面的是DNN中的层间关系的公式。我们可以发现这两组公式在形式上非常接近。如果将$oldsymbol{h}^{(t)}$的这种时间上的展开看成类似于DNN这种层间堆叠的话,可以发现vanilla RNN每一“层”除了有来自上一“层”的输入$oldsymbol{h}^{(t - 1)}$,还有专属于这一层的输入$oldsymbol{x}^{(t)}$,最重要的是,每一“层”的参数$oldsymbol{W}$和$oldsymbol{b}$都是同一组。而DNN则是有专属于那一层的$oldsymbol{W}^{l}$和$oldsymbol{b}^{l}$。
4.2 vanilla RNN的反向梯度推导
RNN反向传播算法的思路和DNN是一样的,即通过梯度下降法一轮轮的迭代,得到合适的RNN模型参数$oldsymbol{U},oldsymbol{W},oldsymbol{V},oldsymbol{b},oldsymbol{c}$。由于我们是基于时间反向传播,所以RNN的反向传播有时也叫做BPTT(back-propagation through time)。当然这里的BPTT和DNN也有很大的不同点,即这里所有的$oldsymbol{U},oldsymbol{W},oldsymbol{V},oldsymbol{b},oldsymbol{c}$在序列的各个位置是共享的,反向传播时我们更新的是相同的参数。
为了简化描述,这里的损失函数我们为交叉熵损失函数,输出的激活函数为softmax函数,隐藏层的激活函数为tanh函数。
如果RNN在序列的每个位置有输出,则最终的损失L为所有时间步$t$的loss之和:
$L = {sumlimits_{t = 1}^{T}L^{(t)}}$
其中,$oldsymbol{V},oldsymbol{c}$的梯度计算比较简单,跟求DNN的BP是一样的。
根据 数学基础篇:矩阵微分与求导 1.8节例子的中间结果,我们可以知道:
$frac{partial L}{partialoldsymbol{c}} = {sumlimits_{t = 1}^{T}frac{partial L^{(t)}}{partialoldsymbol{c}}} = {sumlimits_{t = 1}^{T}{{hat{oldsymbol{y}}}^{(t)} - oldsymbol{y}^{(t)}}}$
$frac{partial L}{partialoldsymbol{V}} = {sumlimits_{t = 1}^{T}frac{partial L^{(t)}}{partialoldsymbol{V}}} = {sumlimits_{t = 1}^{T}left( {{hat{oldsymbol{y}}}^{(t)} - oldsymbol{y}^{(t)}} ight)}left( oldsymbol{h}^{(t)} ight)^{T}$
接下来的$oldsymbol{U},oldsymbol{W},oldsymbol{b}$的梯度计算就相对复杂了。从RNN的模型可以看出,在反向传播时,某一序列位置$t$的梯度由当前位置的输出对应的梯度和序列索引位置$t+1$时的梯度两部分共同决定。对于$oldsymbol{W}$在某一序列位置$t$的梯度损失需要反向传播一步一步地计算。我们定义序列索引$t$位置的隐藏状态的梯度为:
$oldsymbol{delta}^{(t)} = frac{partial L}{partialoldsymbol{h}^{(t)}}$
如果我们能知道$oldsymbol{delta}^{(t)}$,那么根据$oldsymbol{h}^{(t)} = sigmaleft( oldsymbol{z}^{(t)} ight) = sigmaleft( {oldsymbol{U}oldsymbol{x}^{(t)} + oldsymbol{W}oldsymbol{h}^{(t - 1)} + oldsymbol{b}} ight)$我们就像DNN那样套用标量对矩阵的链式求导法则来进一步得到$oldsymbol{U},oldsymbol{W},oldsymbol{b}$的梯度了。
根据4.1节中的示意图我们可以轻易发现,当$t = T$,则误差只有$left. L^{(T)} ightarrowoldsymbol{h}^{(T)} ight.$这么一条。
所以:
$oldsymbol{delta}^{(T)} = oldsymbol{V}^{T}left( {{hat{oldsymbol{y}}}^{(T)} - oldsymbol{y}^{(T)}} ight)$
而当$t<T$时,$oldsymbol{h}^{(t)}$的误差来源有两条:
1)$left. L^{(t)} ightarrowoldsymbol{h}^{(t)} ight.$
2)$left. oldsymbol{h}^{({t + 1})} ightarrowoldsymbol{h}^{(t)} ight.$
于是我们得到:
$oldsymbol{delta}^{(t)} = frac{partial L^{(t)}}{partialoldsymbol{h}^{(t)}} + left( frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{h}^{(t + 1)}}$
我们来逐项求解:
首先对于$frac{partial L^{(t)}}{partialoldsymbol{h}^{(t)}}$:
$oldsymbol{delta}^{(t)} = frac{partial L}{partialoldsymbol{h}^{(t)}} = left( frac{partialoldsymbol{o}^{(t)}}{partialoldsymbol{h}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{o}^{(t)}} = oldsymbol{V}^{T}left( {{hat{oldsymbol{y}}}^{(t)} - oldsymbol{y}^{(t)}} ight)$
对于$left( frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} ight)^{T}frac{partial L^{({t + 1})}}{partialoldsymbol{h}^{(t + 1)}}$,我们先关注$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}$:
因为$oldsymbol{h}^{(t + 1)} = sigmaleft( oldsymbol{z}^{(t)} ight) = sigmaleft( {oldsymbol{U}oldsymbol{x}^{(t + 1)} + oldsymbol{W}oldsymbol{h}^{(t)} + oldsymbol{b}} ight)$
所以有:
$doldsymbol{h}^{(t + 1)} = sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)igodot doldsymbol{z}^{(t)} = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)doldsymbol{z}^{(t)} = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)dleft( {oldsymbol{W}oldsymbol{h}^{(t)}} ight) = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)oldsymbol{W}doldsymbol{h}^{(t)}$
所以有:$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)oldsymbol{W}$
于是:
$oldsymbol{delta}^{(t)} = oldsymbol{V}^{T}left( {{hat{oldsymbol{y}}}^{(t)} - oldsymbol{y}^{(t)}} ight) + oldsymbol{W}^{T}diagleft( {sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)} ight)oldsymbol{delta}^{(t + 1)}$
有了$oldsymbol{delta}^{(T)}$以及从$oldsymbol{delta}^{(t + 1)}$到$oldsymbol{delta}^{(t)}$的递推公式,我们可以轻易求出$oldsymbol{U},oldsymbol{W},oldsymbol{b}$的梯度,由于这三组变量在不同的$t$下是公用的,所以由全微分方程可知,这三个变量应当都是在$t$上的某种累加形式。我们定义只在时间步$t$使用的虚拟变量$oldsymbol{U}^{(t)},oldsymbol{W}^{(t)},oldsymbol{b}^{(t)}$,这样就可以用$frac{partial L}{partialoldsymbol{W}^{(t)}}$来表示$oldsymbol{W}$在时间步$t$的时候对梯度的贡献:
$frac{partial L}{partialoldsymbol{W}} = {sumlimits_{t = 1}^{T}frac{partial L}{partialoldsymbol{W}^{(t)}}} = {sumlimits_{t = 1}^{T}{left( frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{W}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{h}^{(t)}} =}}{sumlimits_{t = 1}^{T}{diagleft( {sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)} ight)oldsymbol{delta}^{(t)}left( oldsymbol{h}^{(t - 1)} ight)^{T}}}$
同理,我们得到:
$frac{partial L}{partialoldsymbol{b}} = {sumlimits_{t = 1}^{T}{frac{partial L}{partialoldsymbol{b}^{(t)}} =}}{sumlimits_{t = 1}^{T}{left( frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{b}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{h}^{(t)}} = {sumlimits_{t = 1}^{T}{diagleft( {sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)} ight)oldsymbol{delta}^{(t)}}}}}$
$frac{partial L}{partialoldsymbol{U}} = {sumlimits_{t = 1}^{T}{frac{partial L}{partialoldsymbol{U}^{(t)}} =}}{sumlimits_{t = 1}^{T}{left( frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{U}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{h}^{(t)}} = {sumlimits_{t = 1}^{T}{diagleft( {sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)} ight)oldsymbol{delta}^{(t)}left( oldsymbol{x}^{(t)} ight)^{T}}}}}$
4.3 RNN发生梯度消失与梯度爆炸的原因分析
上一节我们得到了从$oldsymbol{h}^{(t + 1)}$到$oldsymbol{h}^{(t)}$的递推公式:
$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)oldsymbol{W}$
在求$oldsymbol{h}^{(t)}$的时候,我们需要从$oldsymbol{h}^{(T)}$开始根据上面这个公式一步一步推到$oldsymbol{h}^{(t)}$,可以想象$oldsymbol{W}$在这期间会被疯狂地连乘。当我们要求某个时间步$t$下的$frac{partial L}{partialoldsymbol{W}^{(t)}}$时,这一堆连乘的$oldsymbol{W}$也会被带上。结果就是(粗略地分析),如果$oldsymbol{W}$里的值都比较大,就会发生梯度爆炸,反之则发生梯度消失。
如果本文对您有所帮助的话,不妨点下“推荐”让它能帮到更多的人,谢谢。
参考资料
- 书籍:《Deep Learning》(深度学习)