搭建一个简单的神经网络对mnist数据集中的手写数字数据集进行训练和测试。
输入的每张数据包含784个像素点,第一层为784行256列的矩阵,第二层是256行128列的矩阵,输出层则将结果转换为10个输出值,代表手写数字的10种分类结果,每层有一个权重值weight和偏置bias
代码实现
#搭建两层的神经网络 import numpy as np import tensorflow.compat.v1 as tf tf.disable_v2_behavior() import matplotlib.pyplot as plt import input_data #加载数据集 minst=input_data.read_data_sets('data/data/', one_hot=True) #设置参数 n_hidden_1=256 #第一层输出 n_hidden_2=128 #第二层输出 n_input=784 #输入像素点 n_classes=10 #分类结果 #输入输出 x=tf.placeholder("float",[None,n_input]) y=tf.placeholder("float",[None,n_classes]) #神经网络参数 stddev=0.1 weights={ 'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev=stddev)), #高斯初始化 'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)), 'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev)) } biases={ 'b1':tf.Variable(tf.random_normal([n_hidden_1])), 'b2':tf.Variable(tf.random_normal([n_hidden_2])), 'out':tf.Variable(tf.random_normal([n_classes])) } print('NetWork Ready') def multilayer_perceptron(_X,_weights,_biases): layer_1=tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),biases['b1'])) #sigmoid激活函数 layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,_weights['w2']),_biases['b2'])) return (tf.matmul(layer_2,_weights['out'])+_biases['out']) #预测 pred=multilayer_perceptron(x,weights,biases) #损失和优化器 learning_rate=0.01 cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(y,pred)) optm=tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost) corr=tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) accr=tf.reduce_mean(tf.cast(corr,"float")) #初始化 init=tf.global_variables_initializer() print('function ready') #训练参数 training_epochs=20 batch_size=100 display_step=4 #开始训练 sess=tf.Session() sess.run(init) feeds={} for epoch in range(training_epochs): avg_cost=0. total_batch=int(minst.train.num_examples/batch_size) for i in range(total_batch): batch_xs,batch_ys=minst.train.next_batch(batch_size) feeds={x:batch_xs,y:batch_ys} sess.run(optm,feed_dict=feeds) avg_cost+=sess.run(cost,feed_dict=feeds) avg_cost=avg_cost/total_batch if(epoch+1)%display_step==0: print('Epoch:%03d/%03d cost:%.9f'%(epoch+1,training_epochs,avg_cost)) train_acc=sess.run(accr,feed_dict=feeds) print('TRAIN ACCURACY:%.3f'%(train_acc)) feeds={x:minst.test.images,y:minst.test.labels} test_acc=sess.run(accr,feed_dict=feeds) print('TEST ACCURACY:%.3f'%(test_acc)) print('训练完成')
训练结果