• TensorFlow——MNIST手写数字识别


    MNIST手写数字识别
    MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/
     
    一、数据集介绍:
    MNIST是一个入门级的计算机视觉数据集
    下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)
     
    二、TensorFlow实现MNIST手写数字识别
    (1)构建一个只有输入层和输出层的简单神经网络模型,使用二次代价函数和梯度下降算法进行优化;代码如下:
    #TensorFlow实现MNIST手写数字识别-简单版本
    import tensorflow as tf
    #Tensorflow提供了一个类来处理MNIST数据
    from tensorflow.examples.tutorials.mnist import input_data
     
    #载入数据集
    mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
     
    #设置每个批次的大小
    batch_size=100
    #计算一共有多少个批次
    n_batch=mnist.train.num_examples//batch_size
     
    #定义两个placeholder
    x=tf.placeholder(tf.float32,[None,784])
    y=tf.placeholder(tf.float32,[None,10])
     
    #创建一个简单的神经网络(只有输入层和输出层)
    Weights=tf.Variable(tf.zeros([784,10]))
    biases=tf.Variable(tf.zeros([10]))
    prediction=tf.nn.softmax(tf.matmul(x,Weights)+biases)
     
    #定义代价函数(均方差函数)
    loss=tf.reduce_mean(tf.square(y-prediction))
    #定义反向传播算法(使用梯度下降算法)
    train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)
     
    #结果存放在一个布尔型列表中(argmax函数返回一维张量中最大的值所在的位置)
    correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
     
    #求准确率(tf.cast将布尔值转换为float型)
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
     
    #创建会话
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer()) #初始化变量
        #训练次数
        for i in range(21):
            for batch in range(n_batch):
                batch_xs,batch_ys=mnist.train.next_batch(batch_size)
                sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
     
            acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
            print("Iter"+str(i)+",Testing Accuracy"+str(acc))
    

      结果为:

    (2)模型同上,使用交叉熵函数和梯度下降算法进行优化,
    把上面代码的代价函数改为下面的交叉熵代价函数:
    loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction)) 
    

      结果为:

    (3)构建一个多层的神经网络模型,使用交叉熵函数和梯度下降算法进行优化,添加Dropout防止过拟合;
    模型结构如下:
    代码如下:
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
     
    #载入数据集
    mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
     
    #设置每个批次的大小
    batch_size=100
    #计算一共有多少个批次
    n_batch=mnist.train.num_examples//batch_size
     
    #定义三个placeholder
    x=tf.placeholder(tf.float32,[None,784])
    y=tf.placeholder(tf.float32,[None,10])
    keep_prob=tf.placeholder(tf.float32)  #存放百分率
     
    #创建一个多层神经网络模型
    #第一个隐藏层
    W1=tf.Variable(tf.truncated_normal([784,2000],stddev=0.1))
    b1=tf.Variable(tf.zeros([2000])+0.1)
    L1=tf.nn.tanh(tf.matmul(x,W1)+b1)
    L1_drop=tf.nn.dropout(L1,keep_prob) #keep_prob设置工作状态神经元的百分率
    #第二个隐藏层
    W2=tf.Variable(tf.truncated_normal([2000,2000],stddev=0.1))
    b2=tf.Variable(tf.zeros([2000])+0.1)
    L2=tf.nn.tanh(tf.matmul(L1_drop,W2)+b2)
    L2_drop=tf.nn.dropout(L2,keep_prob)
    #第三个隐藏层
    W3=tf.Variable(tf.truncated_normal([2000,1000],stddev=0.1))
    b3=tf.Variable(tf.zeros([1000])+0.1)
    L3=tf.nn.tanh(tf.matmul(L2_drop,W3)+b3)
    L3_drop=tf.nn.dropout(L3,keep_prob)
    #输出层
    W4=tf.Variable(tf.truncated_normal([1000,10],stddev=0.1))
    b4=tf.Variable(tf.zeros([10])+0.1)
    prediction=tf.nn.softmax(tf.matmul(L3_drop,W4)+b4)
     
    #定义交叉熵代价函数
    loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
    #定义反向传播算法(使用梯度下降算法)
    train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)
     
    #结果存放在一个布尔型列表中(argmax函数返回一维张量中最大的值所在的位置)
    correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
     
    #求准确率(tf.cast将布尔值转换为float型)
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
     
    #创建会话
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer()) #初始化变量
        #训练次数
        for i in range(21):
            for batch in range(n_batch):
                batch_xs,batch_ys=mnist.train.next_batch(batch_size)
                sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:1.0})
            #测试数据计算出的准确率
            test_acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
            print("Iter"+str(i)+",Testing Accuracy"+str(test_acc))
    

      结果为:

     
  • 相关阅读:
    管理员必备的Linux系统监控工具
    kafka入门:简介、使用场景、设计原理、主要配置及集群搭建(转)
    RedHat linux配置yum本地资源
    RedHat Linux RHEL6配置本地YUM源
    c语言中的fgets函数
    sprintf()函数的用法
    spring boot整合JWT例子
    spring boot 自定义过滤器链
    (转)ArrayList和LinkedList的几种循环遍历方式及性能对比分析
    (转)Springboot 中filter 注入对象
  • 原文地址:https://www.cnblogs.com/asialee/p/9245368.html
Copyright © 2020-2023  润新知