import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data set; labels are returned one-hot encoded.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Mini-batch size, and the number of whole mini-batches in one pass over
# the training set (integer division drops any partial final batch).
batch_size = 100
n_batch = mnist.train.num_examples // batch_size

# Graph inputs: x carries 784 features per example, y carries the
# matching 10-way label vectors. The leading None leaves the batch
# dimension unspecified.
x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
# Passed to tf.nn.dropout as the keep probability, i.e. the fraction of
# units left active.
keep_prob = tf.placeholder(dtype=tf.float32)
# Network: three tanh hidden layers (2000, 2000 and 1000 units), each
# followed by dropout, then a 10-way output layer.
# Weights are drawn from a truncated normal distribution with standard
# deviation 0.1 (not variance); biases start at 0.1.
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop_out = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop_out, W2) + b2)
L2_drop_out = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop_out, W3) + b3)
L3_drop_out = tf.nn.dropout(L3, keep_prob)

W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
# Keep the unnormalised scores separate from the softmax output:
# softmax_cross_entropy_with_logits_v2 applies softmax internally, so
# feeding it `prediction` would apply softmax twice and flatten the
# gradients.
logits = tf.matmul(L3_drop_out, W4) + b4
prediction = tf.nn.softmax(logits)

# Quadratic cost, kept for reference as an alternative loss:
# loss = tf.reduce_mean(tf.square(y - prediction))
# Cross-entropy loss, averaged over the batch and computed from the raw
# logits.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))

# Plain gradient descent with learning rate 0.2.
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

# Op that initialises every variable created above.
init = tf.global_variables_initializer()

# Accuracy: an example is correct when the index of the largest
# predicted probability equals the index of the largest label entry.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Cast the booleans to floats and average them to get the accuracy.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Train for 31 epochs. After each epoch, report accuracy on the full test
# set and the full training set.
with tf.Session() as sess:
    sess.run(init)

    # The evaluation feeds are the same every epoch, so build them once.
    # keep_prob is 1.0 here so that no units are dropped while measuring
    # accuracy.
    test_feed = {x: mnist.test.images,
                 y: mnist.test.labels,
                 keep_prob: 1.0}
    train_feed = {x: mnist.train.images,
                  y: mnist.train.labels,
                  keep_prob: 1.0}

    for epoch in range(31):
        # One pass over the training data in mini-batches, keeping 70%
        # of the hidden units active at each step.
        for _ in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs,
                                            y: batch_ys,
                                            keep_prob: 0.7})

        test_acc = sess.run(accuracy, feed_dict=test_feed)
        train_acc = sess.run(accuracy, feed_dict=train_feed)
        print("Iter" + str(epoch + 1) + ",Testing accuracy-" + str(test_acc)
              + ",Train accuracy-" + str(train_acc))
keep_prob = 1.0 时(训练时不使用 dropout)的输出:
Iter1,Testing accuracy-0.944,Train accuracy-0.958545
Iter2,Testing accuracy-0.958,Train accuracy-0.974691
Iter3,Testing accuracy-0.9621,Train accuracy-0.982855
Iter4,Testing accuracy-0.9652,Train accuracy-0.986455
Iter5,Testing accuracy-0.968,Train accuracy-0.988364
Iter6,Testing accuracy-0.9683,Train accuracy-0.989855
Iter7,Testing accuracy-0.9694,Train accuracy-0.990982
Iter8,Testing accuracy-0.9687,Train accuracy-0.991636
Iter9,Testing accuracy-0.9691,Train accuracy-0.992255
Iter10,Testing accuracy-0.9697,Train accuracy-0.9926
Iter11,Testing accuracy-0.9697,Train accuracy-0.992909
Iter12,Testing accuracy-0.9705,Train accuracy-0.993236
Iter13,Testing accuracy-0.9701,Train accuracy-0.993309
Iter14,Testing accuracy-0.9707,Train accuracy-0.993527
Iter15,Testing accuracy-0.9705,Train accuracy-0.993691
Iter16,Testing accuracy-0.9709,Train accuracy-0.993891
Iter17,Testing accuracy-0.9707,Train accuracy-0.993982
Iter18,Testing accuracy-0.9716,Train accuracy-0.994036
Iter19,Testing accuracy-0.9717,Train accuracy-0.994236
Iter20,Testing accuracy-0.9722,Train accuracy-0.994364
Iter21,Testing accuracy-0.9716,Train accuracy-0.994436
Iter22,Testing accuracy-0.972,Train accuracy-0.994509
Iter23,Testing accuracy-0.9722,Train accuracy-0.9946
Iter24,Testing accuracy-0.972,Train accuracy-0.994636
Iter25,Testing accuracy-0.9723,Train accuracy-0.994709
Iter26,Testing accuracy-0.9723,Train accuracy-0.994836
Iter27,Testing accuracy-0.9722,Train accuracy-0.994891
Iter28,Testing accuracy-0.9727,Train accuracy-0.994964
Iter29,Testing accuracy-0.9724,Train accuracy-0.995091
Iter30,Testing accuracy-0.9725,Train accuracy-0.995164
Iter31,Testing accuracy-0.9725,Train accuracy-0.995182
keep_prob = 0.7 时(训练时丢弃 30% 的神经元)的输出:
Iter1,Testing accuracy-0.9187,Train accuracy-0.912709
Iter2,Testing accuracy-0.9281,Train accuracy-0.923782
Iter3,Testing accuracy-0.9357,Train accuracy-0.935236
Iter4,Testing accuracy-0.9379,Train accuracy-0.940855
Iter5,Testing accuracy-0.9441,Train accuracy-0.944564
Iter6,Testing accuracy-0.9463,Train accuracy-0.948164
Iter7,Testing accuracy-0.9472,Train accuracy-0.950182
Iter8,Testing accuracy-0.9515,Train accuracy-0.9544
Iter9,Testing accuracy-0.9548,Train accuracy-0.956455
Iter10,Testing accuracy-0.9551,Train accuracy-0.959091
Iter11,Testing accuracy-0.9566,Train accuracy-0.959891
Iter12,Testing accuracy-0.9594,Train accuracy-0.962036
Iter13,Testing accuracy-0.9592,Train accuracy-0.964236
Iter14,Testing accuracy-0.9585,Train accuracy-0.964818
Iter15,Testing accuracy-0.9607,Train accuracy-0.966
Iter16,Testing accuracy-0.961,Train accuracy-0.9668
Iter17,Testing accuracy-0.9612,Train accuracy-0.967891
Iter18,Testing accuracy-0.9643,Train accuracy-0.969236
Iter19,Testing accuracy-0.9646,Train accuracy-0.969945
Iter20,Testing accuracy-0.9655,Train accuracy-0.970909
Iter21,Testing accuracy-0.9656,Train accuracy-0.971509
Iter22,Testing accuracy-0.9668,Train accuracy-0.972891
Iter23,Testing accuracy-0.9665,Train accuracy-0.972982
Iter24,Testing accuracy-0.9687,Train accuracy-0.974091
Iter25,Testing accuracy-0.9673,Train accuracy-0.974782
Iter26,Testing accuracy-0.9682,Train accuracy-0.975127
Iter27,Testing accuracy-0.9682,Train accuracy-0.976055
Iter28,Testing accuracy-0.9703,Train accuracy-0.976582
Iter29,Testing accuracy-0.9692,Train accuracy-0.976982
Iter30,Testing accuracy-0.9707,Train accuracy-0.977891
Iter31,Testing accuracy-0.9703,Train accuracy-0.978091
可以看到,当训练时使用 dropout 丢弃 30% 的神经元(keep_prob = 0.7)时,训练集与测试集准确率之间的差距比不使用 dropout(keep_prob = 1.0)时要小一些,说明 dropout 有助于缓解过拟合。