• 基于cnn的MNIST_data手写辨识


    # -*- coding: utf-8 -*-
    """
    Created on Sat May 26 16:40:17 2018


    @author: 被遗弃的庸才
    """


    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data


    mnist=input_data.read_data_sets('MNIST_data',one_hot=True)#本地没有这个文件就会下载,有的话就会直接用




    #定义计算正确率的函数
    def compute_accuracy(v_xs,v_ys):
        y_pre=sess.run(prediction,feed_dict={xs:v_xs,ys:v_ys,keep_prob:1})
        correct_prediction=tf.equal(tf.argmax(y_pre,1),tf.arg_max(v_ys,1))#argmax1是按照列优先返回的下标
        accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#这里完成一个类型的转化
        result=sess.run(accuracy,feed_dict={xs:v_xs,ys:v_ys})
        return result


    def weight_variable(shape):
        inital=tf.Variable(tf.random_normal(shape))*0.1#这个是自己测试出来的,每次使用relu的时候都要玄学一下
        #inital=tf.truncated_normal(shape,stddev=0.1)#因为google也是这样玩的,所以这里没有用random_normal
        return tf.Variable(inital)
    def bias_variable(shape):
        inital=tf.constant(0.1,shape=shape)
        return tf.Variable(inital)
    def conv2d(x,W):
        return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')#strides是每一步多长
    def max_pool_2x2(x):
        return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
    #定义输入变量
    xs=tf.placeholder(tf.float32,[None,784])
    ys=tf.placeholder(tf.float32,[None,10])
    keep_prob=tf.placeholder(tf.float32)
    x_image=tf.reshape(xs,[-1,28,28,1])#一维的只用1就行了,-1代表不管sample的维度


    #conv1 layer
    #这个相当于就是fliter
    W_conv1=weight_variable([5,5,1,32])#patch 5*5->劫的大小  size =1和之前的1是一样的 out size 1*32
    b_conv1=bias_variable([32])
    h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)#28*28*32
    h_pool1=max_pool_2x2(h_conv1)#14*14*32
    #cov2
    W_conv2=weight_variable([5,5,32,64])#patch 5*5->劫的大小  size =1 out size 1*64
    b_conv2=bias_variable([64])
    h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)#14*14*64
    h_pool2=max_pool_2x2(h_conv2)#7*7*64




    #function1 layer
    W_fc1=weight_variable([7*7*64,1024])
    b_fc1=bias_variable([1024])
    h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])#这里要进行转换,因为之前的数据是三维的,
    hfc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
    hfcl_drop=tf.nn.dropout(hfc1,keep_prob)


    #func2
    W_fc2=weight_variable([1024,10])
    b_fc2=bias_variable([10])
    prediction=tf.nn.softmax(tf.matmul(hfcl_drop,W_fc2)+b_fc2)


    #损失函数
    loss=tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1]))
    train=tf.train.AdamOptimizer(0.0001).minimize(loss)




    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for i in range(1001):
            batch_x,batch_y=mnist.train.next_batch(100)
            sess.run(train,feed_dict={xs:batch_x,ys:batch_y,keep_prob:0.5})
            if i%50==0:
                print(compute_accuracy(mnist.test.images,mnist.test.labels))
            



























































  • 相关阅读:
    指定HTML标签属性 |Specifying HTML Attributes| 在视图中生成输出URL |高级路由特性 | 精通ASP-NET-MVC-5-弗瑞曼
    传递额外的值 Passing Extra Values |在视图中生成输出URL | 高级路由特性 | 精通ASP-NET-MVC-5-弗瑞曼
    以其他控制器为目标 在视图中生成输出URL
    数组与指针(数组中所有元素的和)
    OC中的指针
    UIScrollView创建相册
    开发之UI篇
    TabBarController
    适配ipone5
    NSDate 哪些事
  • 原文地址:https://www.cnblogs.com/csnd/p/16675642.html
Copyright © 2020-2023  润新知