• TensorFlow实现Softmax Regression识别手写数字


    本章已机器学习领域的Hello World任务----MNIST手写识别做为TensorFlow的开始。MNIST是一个非常简单的机器视觉数据集,是由几万张28像素*28像素的手写数字组成,这些图片只包含灰度值信息。

    下面提取了784维的特征,也就是2828个点展开成一维的结果,所以训练数据是一个55000784的Tensor,label是一个55000*10的tensor。当我们处理多分类任务时,通常需要使用Softmax Regression模型。它的工作原理很简单,将可以判定为某类的特征相加,然后将这些特征转化为判定是这一类的概率。其本质就是多类别逻辑回归。

    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets("MNIST_data",one_hot=True)#从TensorFlow读取数据
    
    print (mnist.train.images.shape,mnist.train.labels.shape)
    print (mnist.test.images.shape,mnist.test.labels.shape)
    print (mnist.validation.images.shape,mnist.validation.labels.shape)
    
    import tensorflow as tf
    sess = tf.InteractiveSession()#创建一个session,之后的运算都在这个session里,不同session的数据和运算是相互独立的
    x = tf.placeholder(tf.float32,[None,784])#输入数据的地方,第一个参数是数据类型,第二个是tensor的shape
    
    W = tf.Variable(tf.zeros([784,10]))#Variable是存储模型参数的,不同于存储数据的tensor一旦使用掉就消失,Variable在模型训练迭代过程中是持久化的。
    b = tf.Variable(tf.zeros([10]))
    
    y = tf.nn.softmax(tf.matmul(x,W)+b)#实现Softmax Regression算法
    
    y_ = tf.placeholder(tf.float32,[None,10])#定义一个真实的label,与下面的结果做比较
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))#计算模型的loss
    
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#定义了损失函数之后,再定义一个优化算法,本代码使用SGD算法
    tf.global_variables_initializer().run()#使用全局参数初始化器初始化参数
    
    for i in range(1000):
        batch_xs,batch_ys = mnist.train.next_batch(100)#每次选择100条数据
        train_step.run({x:batch_xs,y_:batch_ys})#选择好数据之后用SGD算法做迭代
    
    correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))#比较预测结果是否准确
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#转化成准确率
    print (accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))#输出结果,正确率为91.57%
  • 相关阅读:
    奇妙的html 和 Css【关于html、Css 开发中重要的细节和一些小奇怪现象】
    JavaWeb项目img标签的图片无法加载的原因及解决方法
    批量建堆(二叉堆【完全二叉堆】)~~批量建堆
    为什么要面向对象(转)
    不从0开始序列的matlab卷积实验
    虚数、傅里叶变换中负频率的意义
    运动的6个自由度
    转载:关于Ω, f, w的前世今生
    利用matplotlib画用于机器学习的K线图练手任务
    信息工程学院——电子信息工程到底学什么?
  • 原文地址:https://www.cnblogs.com/whig/p/10085090.html
Copyright © 2020-2023  润新知