"""Softmax-regression classifier for MNIST (TensorFlow 1.x graph API).

Loads MNIST through the TF 1.x tutorial helper, trains a single-layer model
``y = softmax(X @ W1 + B1)`` with plain gradient descent on mini-batches, and
reports accuracy on the full test set.

NOTE: ``tensorflow.examples.tutorials.mnist``, ``tf.placeholder`` and
``tf.Session`` exist only in TensorFlow 1.x; under TF 2.x this script needs
``tf.compat.v1`` and a different data loader.
"""
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# ---- data -------------------------------------------------------------------
print("downloading...")  # printed before the (possibly slow) download starts
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print("type:%s" % (type(mnist),))
print("train data size:%d" % mnist.train.num_examples)
print("test data size:%d" % mnist.test.num_examples)
print("train label's shape: %s" % (trainlabel.shape,))

SHOW_EXAMPLES = False  # set to True to plot a few random training digits


def show_examples(images, labels, nsample=5):
    """Plot ``nsample`` randomly chosen training images with their labels.

    Each row of ``images`` is reshaped to 28x28 for display, and the label is
    taken as the argmax of the matching row of ``labels`` (one-hot, since the
    data is loaded with ``one_hot=True``).
    """
    randidx = np.random.randint(images.shape[0], size=nsample)
    for i in randidx:
        cur_img = np.reshape(images[i, :], (28, 28))
        cur_label = np.argmax(labels[i, :])
        plt.matshow(cur_img)
        print("%dth training data, which label is: %d" % (i, cur_label))
    plt.show()


if SHOW_EXAMPLES:
    show_examples(trainimg, trainlabel)

# ---- hyper-parameters -------------------------------------------------------
numClasses = 10
inputSize = 784  # 28*28 pixels, flattened
trainningIterations = 50000  # total gradient-descent steps
batch_size = 100  # examples per step (the only batch size actually used)
learning_rate = 0.05

# ---- model ------------------------------------------------------------------
# Per example: x (1 x 784) @ W1 (784 x 10) + B1 (10) -> logits (1 x 10).
X = tf.placeholder(tf.float32, shape=[None, inputSize])
y = tf.placeholder(tf.float32, shape=[None, numClasses])

W1 = tf.Variable(tf.zeros([inputSize, numClasses]))
B1 = tf.Variable(tf.zeros([numClasses]))

logits = tf.matmul(X, W1) + B1
y_pred = tf.nn.softmax(logits)

# Compute cross-entropy from the logits rather than -sum(y * log(softmax)):
# once softmax saturates to exactly 0, tf.log returns -inf and the loss and
# gradients turn into NaN. The sum over the batch keeps the original loss
# scale, so learning_rate means the same thing as before.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
opt = tf.train.GradientDescentOptimizer(
    learning_rate=learning_rate).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# ---- training / evaluation --------------------------------------------------
# The context manager releases the session's resources when training ends.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(trainningIterations):
        batchInput, batchLabels = mnist.train.next_batch(batch_size)
        feed = {X: batchInput, y: batchLabels}
        sess.run(opt, feed_dict=feed)
        if i % 1000 == 0:
            # Accuracy on the batch just trained on (after the update step).
            train_accuracy = sess.run(accuracy, feed_dict=feed)
            print("step %d, training accuracy %g" % (i, train_accuracy))

    # Evaluate on the entire test set; a single 100-example batch gives a
    # noisy estimate that changes from run to run.
    testAccuracy = sess.run(accuracy, feed_dict={X: testimg, y: testlabel})
    print("test accuracy %g" % testAccuracy)
# Theory reference (理论参考):
# http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html