1. RNN
RNN结构图
计算公式:
代码:
1 model = Sequential() 2 model.add(SimpleRNN(7, batch_input_shape=(None, 4, 2))) 3 model.summary()
运行结果:
可见,共70个参数
记输入维度(x的维度,本例中为2)为dx, 输出维度(h的维度, 与隐藏单元数目一致,本例中为7)为dh
则公式中U的shape应该是dh*dx, W的shape因该是dh*dh, b的shape应该是dh*1
这样计算的h(t)维度才能是dh
计算公式:
nums = dh * ( dh + dx ) + dh
括号中可以理解为x和h(t-1)合并
70 = 7 *( 7 + 2 ) + 7
2. LSTM
https://zhuanlan.zhihu.com/p/147496732
参考这篇吧,讲的不错
LSTM单元结构图
代码:
1 model = Sequential() 2 model.add(LSTM(7, batch_input_shape=(None, 4, 2))) 3 model.summary()
运行结果:
计算公式:
nums = 4 * [ dh * (dh + dx) + dh ]
280 = 4 * [ 7 * (7 + 2) + 7 ]
3. GRU
GRU单元结构图
代码:
1 model = Sequential() 2 model.add(GRU(7, batch_input_shape=(None, 4, 2))) 3 model.summary()
运行结果:
计算方式:
nums = 3 * [ dh * (dh + dx) + dh ]
210 = 3 * [ 7 * (7 + 2) + 7 ]