• 8. Dropout


    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Load MNIST from the "MNIST_data" directory with one-hot encoded labels.
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Mini-batch size.
    batch_size = 64
    # Number of full mini-batches per epoch (integer division drops any remainder).
    n_batch = mnist.train.num_examples // batch_size

    # Three placeholders: inputs (784 features per example), one-hot labels
    # over 10 classes, and the dropout keep probability.
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    # Probability of keeping each hidden unit in the dropout layers; the
    # training loop feeds 0.5 for updates and 1.0 for evaluation.
    keep_prob = tf.placeholder(tf.float32)

    # Fully connected network 784 -> 1000 -> 500 -> 10 with tanh activations
    # and dropout after each hidden layer.
    W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
    b1 = tf.Variable(tf.zeros([1000]) + 0.1)
    L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
    L1_drop = tf.nn.dropout(L1, keep_prob)

    W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
    b2 = tf.Variable(tf.zeros([500]) + 0.1)
    L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
    L2_drop = tf.nn.dropout(L2, keep_prob)

    W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
    b3 = tf.Variable(tf.zeros([10]) + 0.1)
    # Keep the unnormalized scores separately: the loss needs logits, while
    # `prediction` holds the softmax probabilities.
    logits = tf.matmul(L2_drop, W3) + b3
    prediction = tf.nn.softmax(logits)

    # Cross-entropy loss. tf.losses.softmax_cross_entropy applies softmax
    # internally, so it must be given the logits. Passing `prediction` here
    # would apply softmax twice and weaken the gradients.
    loss = tf.losses.softmax_cross_entropy(y, logits)
    # Plain gradient descent with learning rate 0.5.
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    # Op that initializes all variables.
    init = tf.global_variables_initializer()

    # Per-example boolean: does the predicted class (index of the largest
    # output) equal the labelled class (index of the one-hot 1)?
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    # Accuracy = fraction of correct predictions.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    # Train for 31 epochs. After each epoch, report accuracy on the test set
    # and on the full training set, with dropout disabled for evaluation.
    with tf.Session() as sess:
        sess.run(init)

        # Keep probability while updating weights vs. while measuring accuracy.
        train_keep, eval_keep = 0.5, 1.0

        def _accuracy_on(split):
            # Accuracy over every example of one dataset split, without dropout.
            feed = {x: split.images, y: split.labels, keep_prob: eval_keep}
            return sess.run(accuracy, feed_dict=feed)

        for epoch in range(31):
            # One pass over the training data in mini-batches.
            for _ in range(n_batch):
                images, labels = mnist.train.next_batch(batch_size)
                sess.run(train_step,
                         feed_dict={x: images, y: labels, keep_prob: train_keep})

            test_acc = _accuracy_on(mnist.test)
            train_acc = _accuracy_on(mnist.train)
            print("Iter " + str(epoch)
                  + ",Testing Accuracy " + str(test_acc)
                  + ",Training Accuracy " + str(train_acc))
    Extracting MNIST_data\train-images-idx3-ubyte.gz
    Extracting MNIST_data\train-labels-idx1-ubyte.gz
    Extracting MNIST_data\t10k-images-idx3-ubyte.gz
    Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
    Iter 0,Testing Accuracy 0.9201,Training Accuracy 0.91234547
    Iter 1,Testing Accuracy 0.9256,Training Accuracy 0.9229636
    Iter 2,Testing Accuracy 0.9359,Training Accuracy 0.9328182
    Iter 3,Testing Accuracy 0.9375,Training Accuracy 0.93716365
    Iter 4,Testing Accuracy 0.9408,Training Accuracy 0.9411273
    Iter 5,Testing Accuracy 0.9407,Training Accuracy 0.94365454
    Iter 6,Testing Accuracy 0.9472,Training Accuracy 0.9484909
    Iter 7,Testing Accuracy 0.9472,Training Accuracy 0.9502
    Iter 8,Testing Accuracy 0.9516,Training Accuracy 0.95336366
    Iter 9,Testing Accuracy 0.9522,Training Accuracy 0.95552725
    Iter 10,Testing Accuracy 0.9525,Training Accuracy 0.95632726
    Iter 11,Testing Accuracy 0.9566,Training Accuracy 0.9578909
    Iter 12,Testing Accuracy 0.9574,Training Accuracy 0.9606182
    Iter 13,Testing Accuracy 0.9573,Training Accuracy 0.96107274
    Iter 14,Testing Accuracy 0.9587,Training Accuracy 0.9614546
    Iter 15,Testing Accuracy 0.9581,Training Accuracy 0.9616727
    Iter 16,Testing Accuracy 0.9599,Training Accuracy 0.96369094
    Iter 17,Testing Accuracy 0.9601,Training Accuracy 0.96403635
    Iter 18,Testing Accuracy 0.9618,Training Accuracy 0.9658909
    Iter 19,Testing Accuracy 0.9608,Training Accuracy 0.9652
    Iter 20,Testing Accuracy 0.9618,Training Accuracy 0.96607274
    Iter 21,Testing Accuracy 0.9634,Training Accuracy 0.96794546
    Iter 22,Testing Accuracy 0.9639,Training Accuracy 0.96836364
    Iter 23,Testing Accuracy 0.964,Training Accuracy 0.96965456
    Iter 24,Testing Accuracy 0.9644,Training Accuracy 0.9693091
    Iter 25,Testing Accuracy 0.9647,Training Accuracy 0.9703818
    Iter 26,Testing Accuracy 0.9639,Training Accuracy 0.9702
    Iter 27,Testing Accuracy 0.9651,Training Accuracy 0.9708909
    Iter 28,Testing Accuracy 0.9666,Training Accuracy 0.9711818
    Iter 29,Testing Accuracy 0.9644,Training Accuracy 0.9710364
    Iter 30,Testing Accuracy 0.9659,Training Accuracy 0.97205454
  • 相关阅读:
    今发现“最全前端资源汇集”,果断收藏
    js基础
    重排版与重绘
    小乌龟的配置
    考试网站
    苹果手机上时间的兼容
    自定义alert
    [概率dp] 流浪地球
    [权值线段树] 1163B2 Cat Party (Hard Edition)
    [单调栈]1156E Special Segments of Permutation
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/11605482.html
Copyright © 2020-2023  润新知