• TensorFlow学习系列(四):minist实例--用简单的神经网络训练和测试


    神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输出层。

    数据从输入层输入,在隐藏层进行加权变换,最后在输出层进行输出。输出的时候,我们可以使用softmax回归,输出属于每个类别的概率值。

    其中,x1,x2,x3为输入数据,经过运算后,得到三个数据属于某个类别的概率值y1,y2,y3. 用简单的公式表示如下:

    在训练过程中,我们将真实的结果和预测的结果相比(交叉熵比较法),会得到一个残差。公式如下:

    y 是我们预测的概率值, y' 是实际的值。这个残差越小越好,我们可以使用梯度下降法,不停地改变W和b的值,使得残差逐渐变小,最后收敛到最小值。这样训练就完成了,我们就得到了一个模型(W和b的最优化值)。

    完整代码:

    # coding:utf-8
    import tensorflow as tf
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)#导入mnist数据集
    x = tf.placeholder(tf.float32,[None,784])#表示输入图像是n个784维向量
    w = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))#用全部为零的张量来初始化w和b,根据y=xw+b,可知[10]=[784]*[784,10]+[10]
    y = tf.nn.softmax(tf.matmul(x,w)+b)#初始化softmax模型,y=xw+b
    y_true = tf.placeholder('float',[None,10])#正确值
    cross_entropy = -tf.reduce_sum(y_true*tf.log(y))#计算交叉熵
    
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#使用梯度下降算法以0.01的学习速率最小化交叉熵
    
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_true,1))#预测返回一组布尔值例如[True,False,True,True]会变成[1,0,1,1]
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'))
    
    init = tf.initialize_all_variables()
    with tf.Session() as sess:
        sess.run(init)
        for i in range(1000):
            batch_xs,batch_ys = mnist.train.next_batch(100)
            sess.run(train_step,feed_dict={x:batch_xs,y_true:batch_ys})
            if i%100==0:
                print"accuracy:",sess.run(accuracy,feed_dict={x:mnist.test.images,y_true:mnist.test.labels})
                print"correct_prediction:",sess.run(correct_prediction,feed_dict={x:mnist.test.images,y_true:mnist.test.labels})

    每训练100次,测试一次,随着训练次数的增加,测试精度也在增加。训练结束后,1W行数据测试的平均精度为91%左右。

  • 相关阅读:
    hdu5834 Magic boy Bi Luo with his excited tree 【树形dp】
    POJ2152 Fire 【树形dp】
    POJ1848 Tree 【树形dp】
    hdu3586 Information Disturbing 【树形dp】
    BZOJ4557 [JLoi2016]侦察守卫 【树形dp】
    BZOJ4000 [TJOI2015]棋盘 【状压dp + 矩阵优化】
    BZOJ1487 [HNOI2009]无归岛 【仙人掌dp】
    BZOJ4002 [JLOI2015]有意义的字符串 【数学 + 矩乘】
    洛谷P3832 [NOI2017]蚯蚓排队 【链表 + 字符串hash】
    3-3 银行业务队列简单模拟
  • 原文地址:https://www.cnblogs.com/zhoulixue/p/6437862.html
Copyright © 2020-2023  润新知