刚开始学Tensorflow,这里记录学习中的点点滴滴,希望能和大家共同进步。
Cuda和Tensorflow的安装请参考上一篇博客:http://www.cnblogs.com/roboai/p/7768191.html
Tensorflow简单介绍
我们知道,一维的数据可以用数组表示,二维可以用矩阵表示,那么三维或三维以上呢?比如图像,实际上就是一个三维数据[h,w,c],高、宽、通道数,对于灰度图来说,通道数为1,而对于彩色图像,通道数为3。对于这种三维或三维以上的数据,我们称之为张量(tensor),所以顾名思义,Tensorflow的意思就是张量的流动,Tensorflow将数据打包成一个个张量,由四个维度构成,分别是[batch, height, width, channels]
,然后在各个节点之间传递。
节点是Tensorflow里另一重要的概念,对张量的操作称之为节点,一系列的节点构成图。接触过Caffe的朋友可能发现了,这和Caffe里的blob、layer、net是一致的。不同的是,我们需要启动一个会话来计算图,这是Tensorflow的内在机制所决定的。Tensorflow依赖于一个高效的C++后端来进行计算,与后端的这个连接叫做session。一般而言,使用TensorFlow程序的流程是先创建一个图,然后在session中启动它。其思想是先让我们描述一个交互操作图,然后完全将其运行在Python外部。这样做的目的是为了避免频繁切换Python环境和外部环境时需要的开销。如果你想在GPU或者分布式环境中计算时,这一开销会非常可怖,这一开销主要可能是用来进行数据迁移,并不能对计算做出贡献。
我们构建一个简单的图来说明以上过程,改图包含三个节点(两个源节点和一个矩阵乘法节点),然后启动一个会话计算图得到输出结果,最后需要关闭会话。当然也可以使用with代码块实现自动关闭,效果是一样的。
# coding=utf-8 import tensorflow as tf # 该图包含3个节点(两个源节点和乘法节点) matrix1 = tf.constant([[3, 3]]) matrix2 = tf.constant([[2], [2]]) product = tf.matmul(matrix1, matrix2) # 调用会话启动图 sess = tf.Session() result = sess.run(product) # 输出结果并关闭会话 print result sess.close() # 使用“with”代码块自动关闭, 该方法更简洁 with tf.Session() as sess: result = sess.run(product) print result
输出结果为
[[12]]
[[12]]
MNIST数据集
MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片,也包含每一张图片对应的标签,告诉我们这个是数字几。新建一个get.sh文件,写入以下内容,执行该文件就可以下载该数据集。下载下来的数据集被分成两部分,60000行的训练数据集和10000行的测试数据集。每一张图片包含28X28个像素点,我们可以把图片展开成一个向量,长度是 28x28 = 784。
#!/usr/bin/env sh # This scripts downloads the mnist data and unzips it. DIR="$( cd "$(dirname "$0")" ; pwd -P )" cd "$DIR" echo "Downloading..." for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte do if [ ! -e $fname ]; then wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz fi done
Softmax Regression与Cross Entropy
在本文中,我们将采用最简单的网络来预测输入图片中的数字,整个网络仅由一个Softmax Regression构成,数学模型可以写作(y=softmax(Wx+b))。假设(y')是实际分布,(y)是预测分布,Cross Entropy的定义是(loss=sum{y'log{y}})。关于Softmax Regression的反向传递及Cross Entropy的物理含义请参考以下两篇博客,这里就不展开写了。
http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92
http://blog.csdn.net/rtygbwwwerr/article/details/50778098
全连接网络实现手写数字识别
下面终于进入正题了,我们有了数据集,同时也了解了算法流程,剩下的就是写代码实现了。首先是导入包,由于Tensorflow帮我们写了一部分数据读写的程序,我们这里就直接用了。
# coding=utf-8 import tensorflow.examples.tutorials.mnist.input_data as input_data import tensorflow as tf # 导入数据, 强烈建议预先下载 mnist = input_data.read_data_sets("data/", one_hot=True)
这里数据可以用我前面给出的get.sh下载,然后放入data文件夹目录下,我之前是直接用input_data.read_data_sets("data/", one_hot=True)下载的,结果半天下载不下来,所以这里还是建议预先下载吧,用get.sh下载比较快。然后是程序的主要部分。
# 训练集占位符:28*28=784 x = tf.placeholder(tf.float32, [None, 784]) # 初始化参数 W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) # 输出结果 y = tf.nn.softmax(tf.matmul(x, W) + b) # 真实值 y_ = tf.placeholder(tf.float32, [None, 10]) # 计算交叉熵 crossEntropy = -tf.reduce_sum(y_*tf.log(y)) # 训练策略 trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy) # 初始化参数值 init = tf.global_variables_initializer() sess = tf.Session() sess.run(init) # 开始训练:循环训练1000次 for i in range(1000): batchXs, batchYs = mnist.train.next_batch(100) sess.run(trainStep, feed_dict={x: batchXs, y_: batchYs}) # 评估模型 correctPrediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correctPrediction, tf.float32)) print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
这里用的是占位符的方式传入数据,占位符的尺寸为[None, 784],这里的None
表示此张量的第一个维度可以是任何长度的。
权重值W和偏置量b使用Variable来表示,
一个Variable
代表一个可修改的张量,存在在Tensorflow的用于描述交互性操作的图中。它们可以用于计算输入值,也可以在计算中被修改。对于各种机器学习应用,一般都会有模型参数,都可以用Variable
表示。在这里,我们都用全为零的张量来初始化W
和b。
只需要一行代码就可以实现我们的模型y = tf.nn.softmax(tf.matmul(x, W) + b),同样损失函数也只需要一行代码crossEntropy = -tf.reduce_sum(y_*tf.log(y))。
以0.01的学习速率,采用梯度下降法最小化交叉熵,对应的代码为trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy)。
然后初始化参数并训练,定义训练次数为1000,每次随机地选取100图像进行计算。
最后对得到的模型使用测试数据进行评估,评估结果表明精度达到0.9148(每次都不一样,在91%左右徘徊)。
至此,我们采用最简单的一个全连接网络实现了一个手写数字识别的网络,剩下的工作是将这个网络及参数保存,采用自己的图片进行识别,进一步感受这个网络的效果,这一部分将在后续的工作中进行。同时我们可以说这个网络过于简单了,91%的识别效果也远远达不到我们的需求,如何进一步提高网络的精度是我们关注的重点。
关于会话
会话(session)提供在图中执行操作的一些方法。一般的模式是:
- 建立会话,此时会生成一张空图;
- 在会话中添加节点和边,形成一张图;
- 执行图
在调用Session对象的run()方法来执行图时,传入一些Tensor,这个过程叫填充(feed);返回的结果类型根据输入的类型而定,这个过程叫取回(fetch)。
会话是图交互的桥梁,一个会话可以有多个图,会话可以修改图的结构,也可以往图中注入数据进行计算。因此,会话主要由两个API接口--Extend和Run。Extend操作是在Graph中添加节点和边,Run操作是输入计算的节点和填充必要的数据后,进行计算,并输出运算结果。
关于节点与图
图中的节点又称为算子,它代表一个操作(Operation,op),一般用来表示施加的数学运算,也可以表示数据输入(feed in)的起点以及输出(push out)的终点,或者是读取/写入持久变量(persistent variable)的终点。
如果不显式添加一个默认图,系统会自动设置一个全局的默认图。所设置的默认图,在模块范围内定义的节点都将默认加入默认图中。
关于可视化
可视化时,需要在程序中给必要的节点添加摘要(summary),摘要会收集该节点的数据,并标记上第几步、时间戳等标识,写入事件文件(event file)中。
模型存储与加载
TensorFLow的API提供了两种方式存储和加载模型:
(1)生成检查点文件,拓展名一般为.ckpt,通过tf.train.Saver.save()生成。它包含权重和程序中定义的变量,不包含图结构。如果需要在另一个程序中使用,需要重新构建图结构,并告诉TensorFlow如何处理这些权重。
(2)生成图协议文件,这是一个二进制文件,拓展名一般为.pb,用tf.train.write_graph()保存,只包含图形结构,不包含权重,然后使用tf.import_graph_def()来加载图形。
模型训练之Momentum
Momentum是模拟物理学中的动量的概念,更新时在一定程度上保留之前的更新方向,利用当前的批次再微调本次的更新参数,因此引入了一个新的变量v(速度),作为前几次梯度的累加。因此,Momentum能够改善训练过程,在下降初期,前后梯度一致时,能够加速学习;在下降的中后期,在局部最小值附近来回震荡时,能够抑制震荡,加快收敛。