• TFboy养成记 MNIST Classification (主要是如何计算accuracy)


     参考:莫烦。

    主要是运用的MLP。另外这里用到的是批训练:

    这个代码很简单,跟上次的基本没有什么区别。

    这里的lossfunction用到的是是交叉熵cross_entropy.可能网上很多形式跟这里的并不一样。

    这里一段时间会另开一个栏。专门去写一些机器学习上的一些理论知识。

    这里代码主要写一下如何计算accuracy:

    1 def getAccuracy(v_xs,v_ys):
    2     global y_pre
    3     y_v = sess.run(y_pre,feed_dict={x:v_xs})
    4     correct_prediction = tf.equal(tf.arg_max(y_v,1),tf.arg_max(v_ys,1))
    5     accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    6     result = sess.run(accuracy,feed_dict={x:v_xs,y:v_ys})
    7     
    8     return result

    首先得到ground truth,与预测值,然后对着预测值得到tf,arg_max---->你得到的是以float tensor,tensor上的各个值是各个分类结果的可能性,而argmax函数就是求里面的最大值的下表也就是结果。

    注意这里每次得到的是一个batch的结果,也就是说以一个【9,1,2,、。。。。】的这种tensor,所以最后用tf.equal得到一个表示分类值与实际类标是否相同的Bool型tensor。最后把tensor映射到0,1,两个值上就可以了.

    可能会有人问为什么不用int表示而是用float32来表示呢?因为下面腰酸的是准确率,如果是int32,那么按tensorflow的整数除法运算是直接取整数部分不算小数点的。(这几个涉及到的函数在之前的博客)

    全部代码:

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Sun Jun 18 15:31:11 2017
     4 
     5 @author: Jarvis
     6 """
     7 
     8 import tensorflow as tf
     9 import numpy as np
    10 from tensorflow.examples.tutorials.mnist import input_data
    11 
    12 def addlayer(inputs,insize,outsize,activate_func = None):
    13     W = tf.Variable(tf.random_normal([insize,outsize]),tf.float32)
    14     b = tf.Variable(tf.zeros([1,outsize]),tf.float32)
    15     W_plus_b = tf.matmul(inputs,W)+b
    16 
    17     if activate_func == None:
    18         return W_plus_b
    19     else:
    20         return activate_func(W_plus_b)
    21 def getAccuracy(v_xs,v_ys):
    22     global y_pre
    23     y_v = sess.run(y_pre,feed_dict={x:v_xs})
    24     correct_prediction = tf.equal(tf.arg_max(y_v,1),tf.arg_max(v_ys,1))
    25     
    26     accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    27     result = sess.run(accuracy,feed_dict={x:v_xs,y:v_ys})
    28     
    29     return result
    30 mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    31 
    32 x  = tf.placeholder(tf.float32,[None,784])
    33 y = tf.placeholder(tf.float32,[None,10])
    34 #h1 = addlayer(x,784,14*14,activate_func=tf.nn.softmax)
    35 #y_pre = addlayer(h1,14*14,10,activate_func=tf.nn.softmax)
    36 y_pre = addlayer(x,784,10,activate_func=tf.nn.softmax)
    37 
    38 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_pre),reduction_indices=[1]))
    39 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    40 
    41 sess = tf.Session()
    42 sess.run(tf.global_variables_initializer())
    43 for i in range(10001):
    44     x_batch,y_batch = mnist.train.next_batch(100)
    45     sess.run(train_step,feed_dict={x:x_batch,y:y_batch})
    46     
    47     if i % 100 == 0:
    48         print (getAccuracy(mnist.test.images,mnist.test.labels))
    49     
    View Code
  • 相关阅读:
    pymysql 防止sql注入案例
    4、【常见算法】一道经典的额递归题目
    3、【常见算法】寻找非递减序列中绝对值最小值的绝对值
    2、【常见算法】按序重排问题
    9、【排序算法】基数排序
    8、【排序算法】桶排序
    7、【排序算法】归并排序
    6、【排序算法】堆排序
    20、【图】Prim(普里姆)算法
    19、【图】Kruskal(克鲁斯卡尔)算法
  • 原文地址:https://www.cnblogs.com/silence-tommy/p/7045850.html
Copyright © 2020-2023  润新知