Analysis of LSTM Cell Behavior
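The LSTM equations can be written as follows, using the notation of the code further down (a is the (N, 4H) pre-activation, split into four blocks of width H; σ is the sigmoid and ⊙ is element-wise multiplication):

$$
\begin{aligned}
a &= x_t W_x + h_{t-1} W_h + b \\
i &= \sigma(a_i), \quad f = \sigma(a_f), \quad o = \sigma(a_o), \quad g = \tanh(a_g) \\
c_t &= f \odot c_{t-1} + i \odot g \\
h_t &= o \odot \tanh(c_t)
\end{aligned}
$$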
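What struck me as novel is the use of "gates": element-wise multiplication by a matrix of values between 0 and 1 decides how much of the data flow is kept and how much is discarded. This is slightly reminiscent of the activation step in a convolutional neural network.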
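To make the gating idea concrete, here is a minimal sketch. The values in `signal` and the gate pre-activations are made up for illustration; they are not taken from the assignment.

```python
import numpy as np

def sigmoid(z):
    # Plain logistic function; fine for the small values used here.
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical data flowing through one gate.
signal = np.array([2.0, -1.0, 0.5])

# Large positive pre-activation -> gate near 1 (keep),
# large negative -> gate near 0 (discard), zero -> gate of 0.5 (halve).
gate = sigmoid(np.array([10.0, -10.0, 0.0]))

# The gate is applied by element-wise multiplication.
gated = gate * signal
print(gated)
```

Each entry of the gate scales its own entry of the signal independently, which is what lets the LSTM keep some components of the cell state and forget others.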
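For backpropagation, it is clearest to apply the chain rule and work backward one variable at a time.
Pay particular attention to the c_t node. It is an output of the current step, and it is also an input to the step's other output, h_t. Its gradient therefore has two parts: the gradient passed back from upstream, plus the gradient backpropagated through h_t.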
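The two contributions can be written out as a short derivation. This is a sketch under stated assumptions: L is the loss, ⊙ is element-wise multiplication, h_t = o ⊙ tanh(c_t) and c_t = f ⊙ c_{t-1} + i ⊙ g, and `dnext_c`, `dnext_h` are the gradients arriving from upstream for c_t and h_t (the names used in the code).

```latex
\begin{aligned}
\frac{\partial L}{\partial c_t}
  &= \underbrace{\texttt{dnext\_c}}_{\text{passed back from upstream}}
   + \underbrace{\texttt{dnext\_h} \odot o \odot \left(1 - \tanh^2(c_t)\right)}_{\text{through } h_t = o \,\odot\, \tanh(c_t)} \\
\frac{\partial L}{\partial o} &= \texttt{dnext\_h} \odot \tanh(c_t) \\
\frac{\partial L}{\partial f} &= \frac{\partial L}{\partial c_t} \odot c_{t-1}, \qquad
\frac{\partial L}{\partial i} = \frac{\partial L}{\partial c_t} \odot g, \qquad
\frac{\partial L}{\partial g} = \frac{\partial L}{\partial c_t} \odot i \\
\frac{\partial L}{\partial c_{t-1}} &= \frac{\partial L}{\partial c_t} \odot f
\end{aligned}
```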
Forward pass
Forward pass for a single LSTM cell
import numpy as np

def lstm_step_forward(x, prev_h, prev_c, Wx, Wh, b):
    """
    Forward pass for a single timestep of an LSTM.

    The input data has dimension D, the hidden state has dimension H, and we
    use a minibatch size of N.

    Inputs:
    - x: Input data, of shape (N, D)
    - prev_h: Previous hidden state, of shape (N, H)
    - prev_c: Previous cell state, of shape (N, H)
    - Wx: Input-to-hidden weights, of shape (D, 4H)
    - Wh: Hidden-to-hidden weights, of shape (H, 4H)
    - b: Biases, of shape (4H,)

    Returns a tuple of:
    - next_h: Next hidden state, of shape (N, H)
    - next_c: Next cell state, of shape (N, H)
    - cache: Tuple of values needed for backward pass.
    """
    # sigmoid is the numerically stable helper provided elsewhere in the
    # assignment code; it is not defined in this post.
    _, H = prev_h.shape
    # All four pre-activations come from one matrix product, shape (N, 4H).
    a = x.dot(Wx) + prev_h.dot(Wh) + b
    i = sigmoid(a[:, :H])            # input gate
    f = sigmoid(a[:, H:2 * H])       # forget gate
    o = sigmoid(a[:, 2 * H:3 * H])   # output gate
    g = np.tanh(a[:, 3 * H:])        # candidate cell update
    next_c = f * prev_c + i * g
    next_h = o * np.tanh(next_c)
    cache = (i, f, o, g, x, prev_h, prev_c, Wx, Wh, b, next_c)
    return next_h, next_c, cache
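The step function calls a `sigmoid` helper that this post does not show. The sketch below is one common way to write a numerically stable version; it is an assumption about what such a helper could look like, not the assignment's own implementation.

```python
import numpy as np

def sigmoid(x):
    """Numerically stable logistic sigmoid (illustrative stand-in)."""
    x = np.asarray(x, dtype=np.float64)
    # exp(-|x|) always lies in (0, 1], so it can never overflow.
    z = np.exp(-np.abs(x))
    # For x >= 0: 1 / (1 + exp(-x)).  For x < 0: exp(x) / (1 + exp(x)).
    return np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))

vals = sigmoid(np.array([-1000.0, 0.0, 1000.0]))
print(vals)
```

The naive form `1 / (1 + np.exp(-x))` overflows inside `exp` for large negative inputs, which is why the exponent is kept non-positive here.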
Forward pass for the whole LSTM layer
def lstm_forward(x, h0, Wx, Wh, b):
    """
    Forward pass for an LSTM over an entire sequence of data.

    We assume an input sequence composed of T vectors, each of dimension D.
    The LSTM uses a hidden size of H, and we work over a minibatch containing
    N sequences. After running the LSTM forward, we return the hidden states
    for all timesteps.

    Note that the initial hidden state is passed as input, but the initial
    cell state is set to zero. Also note that the cell state is not returned;
    it is an internal variable to the LSTM and is not accessed from outside.

    Inputs:
    - x: Input data of shape (N, T, D)
    - h0: Initial hidden state of shape (N, H)
    - Wx: Weights for input-to-hidden connections, of shape (D, 4H)
    - Wh: Weights for hidden-to-hidden connections, of shape (H, 4H)
    - b: Biases of shape (4H,)

    Returns a tuple of:
    - h: Hidden states for all timesteps of all sequences, of shape (N, T, H)
    - cache: Values needed for the backward pass.
    """
    N, T, D = x.shape
    next_c = np.zeros_like(h0)   # the initial cell state is zero
    next_h = h0
    h, cache = [], []
    for t in range(T):
        next_h, next_c, cache_step = lstm_step_forward(
            x[:, t, :], next_h, next_c, Wx, Wh, b)
        h.append(next_h)
        cache.append(cache_step)
    # Note: the stacked list has shape (T, N, H) and must be transposed
    # to (N, T, H).
    h = np.array(h).transpose(1, 0, 2)
    return h, cache
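The final transpose is easy to get wrong, so here is a small shape check. The sizes and the fill values are made up for illustration; `np.stack(..., axis=1)` is shown as an equivalent alternative, not as what the code above uses.

```python
import numpy as np

N, T, H = 2, 3, 4
# One (N, H) array per timestep, filled with the timestep index
# so that the time axis is easy to identify afterwards.
steps = [np.full((N, H), t, dtype=np.float64) for t in range(T)]

# Stacking a list of T arrays puts time first: (T, N, H).
stacked = np.array(steps)

# Route 1: transpose to (N, T, H), as in lstm_forward.
h_a = stacked.transpose(1, 0, 2)

# Route 2: stack directly along axis 1.
h_b = np.stack(steps, axis=1)

print(stacked.shape, h_a.shape, h_b.shape)
```

Both routes give the same array; after the transpose, `h_a[n, t]` is the hidden state of sequence `n` at timestep `t`.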
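Backward pass
Note that in the actual backward pass, the initial gradient of the cell state c is something we initialize ourselves, whereas the gradient of h is inherited from the layer above (a classifier, or the layer that maps h to word scores; the handling of h is in fact the same as in a vanilla RNN).
Backward pass for a single LSTM cell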
def lstm_step_backward(dnext_h, dnext_c, cache):
    """
    Backward pass for a single timestep of an LSTM.

    Inputs:
    - dnext_h: Gradients of next hidden state, of shape (N, H)
    - dnext_c: Gradients of next cell state, of shape (N, H)
    - cache: Values from the forward pass

    Returns a tuple of:
    - dx: Gradient of input data, of shape (N, D)
    - dprev_h: Gradient of previous hidden state, of shape (N, H)
    - dprev_c: Gradient of previous cell state, of shape (N, H)
    - dWx: Gradient of input-to-hidden weights, of shape (D, 4H)
    - dWh: Gradient of hidden-to-hidden weights, of shape (H, 4H)
    - db: Gradient of biases, of shape (4H,)
    """
    i, f, o, g, x, prev_h, prev_c, Wx, Wh, b, next_c = cache
    tanh_c = np.tanh(next_c)
    # next_h = o * tanh(next_c)
    do = dnext_h * tanh_c
    # As discussed in the analysis above, next_c receives gradient from two
    # paths: dnext_c directly, and dnext_h through tanh. A new array is used
    # so that the caller's dnext_c is not modified in place.
    dc = dnext_c + dnext_h * o * (1 - tanh_c ** 2)
    # next_c = f * prev_c + i * g
    di, df, dg, dprev_c = dc * g, dc * prev_c, dc * i, dc * f
    # Local derivatives written in terms of the outputs:
    # sigmoid' = s * (1 - s), tanh' = 1 - t**2. Block order is i, f, o, g.
    da = np.concatenate(
        [i * (1 - i) * di, f * (1 - f) * df, o * (1 - o) * do, (1 - g ** 2) * dg],
        axis=1)
    db = np.sum(da, axis=0)
    dx = da.dot(Wx.T)
    dWx = x.T.dot(da)
    dprev_h = da.dot(Wh.T)
    dWh = prev_h.T.dot(da)
    return dx, dprev_h, dprev_c, dWx, dWh, db
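One way to convince yourself that the cell-state gradient needs both terms is a numerical check. The sketch below is self-contained and only covers `dprev_c`. The helper names `step_forward`, `dprev_c_analytic` and `loss`, the random sizes, and the plain (not overflow-safe) `sigmoid` are all assumptions made for this illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def step_forward(x, prev_h, prev_c, Wx, Wh, b):
    # Same computation as a single LSTM step, with a reduced cache.
    H = prev_h.shape[1]
    a = x.dot(Wx) + prev_h.dot(Wh) + b
    i, f, o = sigmoid(a[:, :H]), sigmoid(a[:, H:2*H]), sigmoid(a[:, 2*H:3*H])
    g = np.tanh(a[:, 3*H:])
    next_c = f * prev_c + i * g
    next_h = o * np.tanh(next_c)
    return next_h, next_c, (i, f, o, g, next_c)

def dprev_c_analytic(dnext_h, dnext_c, cache):
    i, f, o, g, next_c = cache
    # Two paths into next_c: directly, and through next_h.
    dc = dnext_c + dnext_h * o * (1 - np.tanh(next_c) ** 2)
    return f * dc

rng = np.random.default_rng(0)
N, D, H = 3, 4, 5
x = rng.standard_normal((N, D))
prev_h = rng.standard_normal((N, H))
prev_c = rng.standard_normal((N, H))
Wx = rng.standard_normal((D, 4 * H))
Wh = rng.standard_normal((H, 4 * H))
b = rng.standard_normal(4 * H)
dnext_h = rng.standard_normal((N, H))
dnext_c = rng.standard_normal((N, H))

def loss(c):
    # Scalar whose gradients w.r.t. next_h and next_c are dnext_h and dnext_c.
    next_h, next_c, _ = step_forward(x, prev_h, c, Wx, Wh, b)
    return np.sum(next_h * dnext_h) + np.sum(next_c * dnext_c)

_, _, cache = step_forward(x, prev_h, prev_c, Wx, Wh, b)
analytic = dprev_c_analytic(dnext_h, dnext_c, cache)
only_upstream = cache[1] * dnext_c   # drops the path through next_h

# Centered finite differences, one entry of prev_c at a time.
eps = 1e-6
numeric = np.zeros_like(prev_c)
for idx in np.ndindex(*prev_c.shape):
    c_plus, c_minus = prev_c.copy(), prev_c.copy()
    c_plus[idx] += eps
    c_minus[idx] -= eps
    numeric[idx] = (loss(c_plus) - loss(c_minus)) / (2 * eps)

max_err = np.max(np.abs(analytic - numeric))
print("max abs error with both terms:", max_err)
print("matches with upstream term only:", np.allclose(only_upstream, numeric))
```

With both terms the analytic and numerical gradients should agree closely; keeping only the upstream term should not.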
Backward pass for the whole LSTM layer
def lstm_backward(dh, cache):
    """
    Backward pass for an LSTM over an entire sequence of data.

    Inputs:
    - dh: Upstream gradients of hidden states, of shape (N, T, H)
    - cache: Values from the forward pass

    Returns a tuple of:
    - dx: Gradient of input data of shape (N, T, D)
    - dh0: Gradient of initial hidden state of shape (N, H)
    - dWx: Gradient of input-to-hidden weight matrix of shape (D, 4H)
    - dWh: Gradient of hidden-to-hidden weight matrix of shape (H, 4H)
    - db: Gradient of biases, of shape (4H,)
    """
    N, T, H = dh.shape
    _, D = cache[0][4].shape   # cache[t][4] is x at timestep t, shape (N, D)
    dx = []
    # Accumulators use NumPy's default float64 rather than float32, so the
    # summed gradients are not truncated.
    dWx = np.zeros((D, 4 * H))
    dWh = np.zeros((H, 4 * H))
    db = np.zeros(4 * H)
    # The gradients flowing in from beyond the last timestep are zero.
    step_dprev_h = np.zeros((N, H))
    step_dprev_c = np.zeros((N, H))
    for t in range(T - 1, -1, -1):   # range, not xrange, on Python 3
        step_dx, step_dprev_h, step_dprev_c, step_dWx, step_dWh, step_db = \
            lstm_step_backward(dh[:, t, :] + step_dprev_h, step_dprev_c, cache[t])
        dx.append(step_dx)   # every input node has its own gradient
        dWx += step_dWx      # parameters are shared across timesteps: accumulate
        dWh += step_dWh      # parameters are shared across timesteps: accumulate
        db += step_db        # parameters are shared across timesteps: accumulate
    # Only the initial h0 (in image captioning, the projection of the image
    # features) needs its gradient kept.
    dh0 = step_dprev_h
    # dx was collected from the last timestep to the first, so reverse it,
    # then transpose (T, N, D) to (N, T, D).
    dx = np.array(dx[::-1]).transpose((1, 0, 2))
    return dx, dh0, dWx, dWh, db
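The reason the shared parameters are accumulated with `+=` can be seen on a much simpler model. The sketch below is not an LSTM: it assumes a toy layer that applies the same weight matrix `W` at every timestep, `y[:, t, :] = x[:, t, :] @ W`, with made-up sizes and random data.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, D, M = 2, 3, 4, 5
x = rng.standard_normal((N, T, D))
dy = rng.standard_normal((N, T, M))   # upstream gradient for each timestep

# Per-timestep accumulation, walking backward through time
# in the same way as lstm_backward does.
dW = np.zeros((D, M))
for t in range(T - 1, -1, -1):
    dW += x[:, t, :].T.dot(dy[:, t, :])

# The same gradient computed in one shot over all timesteps.
dW_all = np.einsum('ntd,ntm->dm', x, dy)

print(np.allclose(dW, dW_all))
```

Because `W` is used once per timestep, its total gradient is the sum of the per-timestep gradients; the same reasoning applies to `dWx`, `dWh` and `db` in the LSTM.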