这篇用一种更简洁的方式来实现CNN(卷积神经网络)
上代码:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from the 'MNIST_data' directory with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


def conv_layer(inputs, channels_in, channels_out):
    """Build a 5x5 convolution layer followed by ReLU.

    Args:
        inputs: 4-D input tensor passed to ``tf.nn.conv2d``.
        channels_in: Number of channels of ``inputs``.
        channels_out: Number of filters (output channels).

    Returns:
        The ReLU-activated result of the convolution plus bias.
    """
    w = tf.Variable(
        tf.truncated_normal([5, 5, channels_in, channels_out], stddev=0.1))
    b = tf.Variable(tf.constant(0.1, shape=[channels_out]))
    # Stride 1 with 'SAME' padding keeps the spatial size unchanged.
    conv = tf.nn.conv2d(inputs, w, strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(conv + b)


def fc_layer(inputs, channels_in, channels_out, activate_function=None):
    """Build a fully connected layer.

    Args:
        inputs: 2-D input tensor multiplied by the weight matrix.
        channels_in: Number of input features.
        channels_out: Number of output units.
        activate_function: Optional callable applied to the affine output;
            when ``None`` the raw affine output is returned.

    Returns:
        ``inputs @ w + b``, passed through ``activate_function`` if given.
    """
    # stddev must be set explicitly: the default of 1.0 yields huge initial
    # activations for a layer with thousands of inputs and stalls training.
    w = tf.Variable(tf.truncated_normal([channels_in, channels_out], stddev=0.1))
    b = tf.Variable(tf.constant(0.1, shape=[channels_out]))
    fc = tf.matmul(inputs, w) + b
    if activate_function is None:
        return fc
    return activate_function(fc)


# Inputs: flattened 28x28 images and one-hot labels for 10 classes.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
x_image = tf.reshape(x, [-1, 28, 28, 1])
# Probability of keeping a unit in dropout; fed per run (train vs. eval).
keep_prob = tf.placeholder(tf.float32)

conv1 = conv_layer(x_image, 1, 32)
# A 2x2 window with stride 2 is a real max pool; a 1x1 window would only
# subsample and throw away three quarters of the activations.
pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                       padding='SAME')  # 14*14*32

conv2 = conv_layer(pool1, 32, 64)
pool2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                       padding='SAME')  # 7*7*64

flattended = tf.reshape(pool2, [-1, 7 * 7 * 64])

fc1 = fc_layer(flattended, 7 * 7 * 64, 1024, tf.nn.relu)
fc1 = tf.nn.dropout(fc1, keep_prob)
logits = fc_layer(fc1, 1024, 10)

cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
)
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

init = tf.global_variables_initializer()

# The context manager closes the session even if training raises.
with tf.Session() as sess:
    sess.run(init)

    for i in range(10000):
        images, labels = mnist.train.next_batch(100)
        # Dropout is only active when keep_prob < 1, so train with 0.5.
        sess.run(train_step,
                 feed_dict={x: images, y: labels, keep_prob: 0.5})
        if i % 50 == 0:
            # Evaluate on the first 1000 test samples with dropout disabled.
            print(sess.run(accuracy,
                           feed_dict={x: mnist.test.images[:1000],
                                      y: mnist.test.labels[:1000],
                                      keep_prob: 1.0}))
1.这段代码封装了CNN的每一层
2.这段代码还用到了dropout来防止过拟合(注意:训练时keep_prob必须小于1,例如0.5,dropout才会真正生效;测试时应设为1)
3.不知道为什么,这样的写法似乎比不封装的版本要慢,最终准确率也无法达到98%~99%。封装本身并不会改变计算图,所以差异更可能来自超参数,例如池化窗口的大小(ksize)、权重初始化的标准差(stddev)等