• mnist手写数字检测


    # -*- coding: utf-8 -*-
    """
    Created on Tue Apr 23 06:16:04 2019
    
    @author: 92958
    """
    
    import numpy as np
    import tensorflow as tf
    
    #下载并载入mnist(55000*28*28图片)
    #from tensorflow.examples.tutorials.mnist import input_data
    
    #创造变量mnist,用特定函数,接收
    mnist = input_data.read_data_sets('F:\python\TensorFlow\mnist\mnist_data\',one_hot=True)
    #one_hot独热码,例,:0001000000
    
    #None表示tensor的第一个维度可以是任何长度
    input_x = tf.placeholder(tf.float32,[None,28*28])/255.   #除255表示255个灰度值
    output_y = tf.placeholder(tf.int32,[None, 10])       #10个输出标签
    input_x_images = tf.reshape(input_x, [-1,28,28,1])     #改变形状之后的输出  
          
    #从Test选3000个数据
    test_x = mnist.test.images[:3000]#图片
    test_y = mnist.test.labels[:3000]#标签
    
    #日志
    path = "F:\python\TensorFlow\mnist\log"
    
    #构建第一层神经网络
    conv1 = tf.layers.conv2d(
            inputs=input_x_images,   #形状28.28.1
            filters =32,             #32个过滤器输出深度32
            kernel_size=[5,5],       #过滤器在二维的大小5*5
            strides=1,               #步长为1
            padding='same',          #same表示输出大小不变,因此外围补零两圈
            activation=tf.nn.relu    #激活函数为relu
            )
    #输出得到28*28*32
    
    #第一层池化层pooling(亚采样)
    pool1 = tf.layers.max_pooling2d(
            inputs=conv1,       #形状为28*28*32
            pool_size=[2,2],     #过滤器大小2*2
            strides=2,          #步长为2
            )
    #形状14*14*32
    
    #第二层卷积层
    conv2 = tf.layers.conv2d(
            inputs=pool1,            #形状14*14*32
            filters =64,             #32个过滤器输出深度64
            kernel_size=[5,5],       #过滤器在二维的大小5*5
            strides=1,               #步长为1
            padding='same',          #same表示输出大小不变,因此外围补零两圈
            activation=tf.nn.relu    #激活函数为relu
            )
    #形状14*14*64
    
    #第二层池化层pooling(亚采样)
    pool2 = tf.layers.max_pooling2d(
            inputs=conv2,       #形状为14*14*64
            pool_size=[2,2],     #过滤器大小2*2
            strides=2,          #步长为2
            )
    #形状7*7*64
    
    #平坦化(flat)
    flat = tf.reshape(pool2,[-1,7*7*64])   #形状7*7*64
    
    #全连接层
    dense = tf.layers.dense(inputs = flat, 
                            units=1024, #有1024个神经元
                            activation=tf.nn.relu#激活函数relu
                            )
    
    #dropout:丢弃50%,rate=0.5
    dropout = tf.layers.dropout(inputs=dense, rate=0.5)
    
    #10个神经元的全连接层,这里不用激活函数来做非线性化
    logits=tf.layers.dense(inputs=dropout,units=10)#输出1*1*10
    
    #计算误差,(计算cross entropy(交叉熵),再用softmax计算百分比概率)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=output_y,
                                           logits=logits)
    #Adam优化器来最小化误差
    train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
    
    #精度
    #返回
    accuracy = tf.metrics.accuracy(
            labels=tf.argmax(output_y,axis=1),
            predictions=tf.argmax(logits,axis=1),)[1]
    
    
    
    #创建会话
    sess = tf.Session()
    
    #初始化变量全局和局部
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    
    sess.run(init)
    writer =tf.summary.FileWriter(path,sess.graph)
    for i in range(1000):
        batch = mnist.train.next_batch(50)
        #从train数据集里取下一个50个样本
        train_loss,train_op_= sess.run([loss,train_op],
                                       {input_x:batch[0],output_y:batch[1]})
        if i%100==0:
            test_accuracy = sess.run(accuracy,
                                     {input_x:test_x,output_y:test_y})
            print("Step=",i)
            print("Train loss=",train_loss)
            print("Test accuracy=",test_accuracy)
    
            
    #测试
    test_output=sess.run(logits,{input_x:test_x[:20]})
    inferenced_y=np.argmax(test_output,1)
    print(inferenced_y,'推测')
    print(np.argmax(test_y[:20],1),'真实')
    

    mnist数据集http://yann.lecun.com/exdb/mnist/

  • 相关阅读:
    CF #305(Div.2) D. Mike and Feet(数学推导)
    CF #305 (Div. 2) C. Mike and Frog(扩展欧几里得&&当然暴力is also no problem)
    2015百度之星资格赛.1004放盘子(数学推导)
    poj.1988.Cube Stacking(并查集)
    lightoj.1048.Conquering Keokradong(二分 + 贪心)
    CMD 命令汇总
    PLSQL 安装与配置 Oracle
    用 jQuery 实现简单倒计时功能
    C# 从服务器下载文件并保存到客户端
    用 NPOI 组件实现数据导出
  • 原文地址:https://www.cnblogs.com/nnmaitian/p/10759866.html
Copyright © 2020-2023  润新知