• tensorflow之分类学习


    写在前面的话

    MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
    参考视频:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-01-classifier/

    MNIST编程

    代码全文

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets('MNIST_data',one_hot = True)
    
    def add_layer(inputs, in_size, out_size, activation_function=None):
        Weights = tf.Variable(tf.random_normal([in_size, out_size]))
        biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
        Wx_plus_b = tf.matmul(inputs, Weights) + biases
        if activation_function is None:
            outputs = Wx_plus_b
        else:
            outputs = activation_function(Wx_plus_b)
        return outputs
    
    def compute_accuracy(v_xs,v_ys):
    	global prediction
    	y_pre = sess.run(prediction,{xs:v_xs})
    	correct_prediction = tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))
    	accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    	result = sess.run(accuracy,{xs:v_xs,ys:v_ys})
    	return result
    
    
    xs = tf.placeholder(tf.float32,[None,784])
    ys = tf.placeholder(tf.float32,[None,10])
    
    # add hiden layer
    prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)
    
    #  the error
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1]))
    
    # train
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    
    for i in range(1000):
    	batch_xs,batch_ys = mnist.train.next_batch(100)
    	sess.run(train_step,{xs:batch_xs,ys:batch_ys})
    	if i % 50 == 0:
    		print(compute_accuracy(mnist.test.images,mnist.test.labels))
    
    
    

    打印结果

    解释

    总的来说,整个程序分为两层,输入层和输出层,没有隐藏层。用到的激励函数为 softmax 函数。

    与之前的 tensorflow之曲线拟合 相比,不同并值得记录的有以下几点:

    • 1 one_hot = True

    表示使用热编码 ,什么是热编码呢?这里举个例子
    如用[1,0,0,0,0,0,0]表示星期一,[0,1,0,0,0,0,0]表示星期二。这里是手写数字识别,识别0~9 共10个数字。所以这里用热编码来表示5的话就是: [0,0,0,0,0,1,0,0,0,0]。在实际预测过程中,预测值可能为 [0,0,0.1,0,0,0.6,0,0,0.3,0] 这样的形式,代表预测到0~9某一数字的概率。

    • 2 prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)

    激励函数为 softmax ,这里为什么用这个函数,对于分类问题, “最后一层输出会使用Softmax函数进行概率化输出”

    • 3 cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1]))

    cross_entropy称为交叉熵 ,该误差模型被广泛应用于机器学习的分类问题。

    • 4 mnist.train.next_batch(100)

    这里代表每次从训练样本中拿出100个训练样本样本数据来进行训练。进行1000次迭代循环,实际上使用了100 * 1000 个训练样本。

    • 5 compute_accuracy(mnist.test.images,mnist.test.labels)

    函数本身的作用用来计算训练模型后用测试样本来检测准确率的。与上一条解释对比可以看到变量的不同。一个是 train ,一个是test 。这样做是用训练数据去训练模型,用测试数据来测试模型,更能测试模型的鲁棒性。

    • 6 tf.argmax()

    tf.argmax(vector, 1):返回的是vector中的最大值的索引号,如果vector是一个向量,那就返回一个值,如果是一个矩阵,那就返回一个向量,这个向量的每一个维度都是相对应矩阵行的最大值元素的索引号。

    import tensorflow as tf
    import numpy as np
     
    A = [[1,3,4,5,6]]
    B = [[1,3,4], [2,4,1]]
     
    with tf.Session() as sess:
        print(sess.run(tf.argmax(A, 1)))
        print(sess.run(tf.argmax(B, 1)))
    
    #输出:
    #[4]
    #[2 1]
    

    这里返回最大数值的所引值,实际上就是返回了0~9概率最大的那个数值,然后识别的结果和测试样本的正确结果对比。

    关于第二个参数,这涉及到axis ,对于axis取值0或1,关系的是在计算上矩阵时是 的方向还是 的方向。这个百度一下就知道。

    • 7 tf.cast()

    cast(x, dtype, name=None)
    将x的数据格式转化成dtype.例如,原来x的数据格式是bool, 那么将其转化成float以后,就能够将其转化成0和1的序列。反之也可以。例如

    a = tf.Variable([1,0,0,1,1])
    b = tf.cast(a,dtype=tf.bool)
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    print(sess.run(b))
    #输出[ True False False  True  True]
    

    解析参考:
    https://blog.csdn.net/uestc_c2_403/article/details/72232807
    https://blog.csdn.net/luoganttcc/article/details/70315538

    总结

    • 1 对于线性/非线性拟合
      ① 常用激励函数:relu
      ② 常用平方差,计算误差

    • 2 对于分类问题
      ① 最后一层常用Softmax 激励函数进行概率化输出
      ② 常用交叉熵 ,计算误差

    • 3 对于大批量数据,常用分批次训练,next_batch

  • 相关阅读:
    http://www.kankanews.com/ICkengine/archives/18078.shtml
    c# ArrayList 的排序问题!
    MVC各种传值方式
    MVC3学习第五章 排山倒海第一变母版页,模型
    MVC3学习第三章 剑出鞘之前奏控制器,URL路由
    MVC3学习第二章 剑出鞘之看剑vs2010安装MVC3和建立你的第一个MVC3项目
    MVC3学习第四章 剑出鞘之后续MVC3的新特性之Razor视图解析
    MVC3学习第一章 掀起它的盖头来
    有关匿名函数执行与传参
    ubuntu12.04安装jdk7u79linuxi586.tar.gz
  • 原文地址:https://www.cnblogs.com/maskerk/p/9981861.html
Copyright © 2020-2023  润新知