第一步仍然是导入库和数据集。
''' To classify images using a reccurent neural network, we consider every image row as a sequence of pixels. Because MNIST image shape is 28*28px, we will then handle 28 sequences of 28 steps for every sample. '''
这里我们设定了各种参数,此时的n_steps是指按照28个时间点,依次输入数据。因为mnist数据集是28*28的,每次只能输入一行,所以n_input是28,分成28次顺序输入。
y是输出,仍然是一个10维数组。None设为并行数。
注意这里的权值,weights矩阵变成了[n_hidden,n_classes],前面是指这个RNN一共有多少个隐藏的单元(如下图):
中间那些绿色的状态,可以看作是隐藏的单元。之后建立一个矩阵,将隐藏单元转换成一个数组。当然再分类问题上,更常见的做法是在RNN上加一层全连接层(就像BP网络一样),在进行输出,在效果上比纯粹进行一次矩阵运算会好一些。
来看网络结构。
tf.unstack() 将给定的R维张量拆分成R-1维张量:将value根据axis分解成num个张量,返回的值是list类型,如果没有指定num则根据axis推断出。在这个李子中,我们通过这个函数将x按照第二维切开来,分成一行一行的,x变成了一个list,每个元素就是一次输出,list的长度就是step数。
BasicLSTMCell()函数中的forget_bias参数,一般设置为1.0。
众所周知,lstm单元有两个输出,一个是h,一个是$s_n$。此处我们将两个豆取回,然后RNN最终输出那个分类,即做了一次线性变换的值。
我们实例化RNN,然后定义cost和optimizer。
通过比较label和输出,来得到一个正确与否的0-1矩阵,之后计算精确度。
最后初始化所有的变量,开始训练。
其中用while循环来控制循环体,将x,y作为feed_dict输入网络,迭代运算损失。
最后进行测试。sess.run()这里表示进行accuracy的计算。