• tensorflow 全连接神经网络识别mnist数据


        之前没有学过tensorflow,所以使用tensorflow来对mnist数据进行识别,采用最简单的全连接神经网络,第一层是784,(输入层),隐含层是256,输出层是10

    ,相关注释卸载程序中。

      1 #!/usr/bin/env python 3.6
      2 #_*_coding:utf-8 _*_
      3 #@Time    :2020/2/12 15:34
      4 #@Author  :hujinzhou 
      5 #@FileName: mnist.py
      6 
      7 #@Software: PyCharm
      8 import tensorflow as tf
      9 import tensorflow.examples.tutorials.mnist.input_data as input_data
     10 import matplotlib.pyplot as plt
     11 import numpy as np
     12 from time import time
     13 mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)#通过tensorflow下载mnist数据集
     14 """图片的显示"""
     15 def plot_image(image):
     16     plt.imshow(image.reshape(28,28),cmap='binary')#tensorflow中的数据是将图片平铺成一列的存储,
     17                                                   # 所以显示的时候应该reshape成28*28
     18     plt.show()
     19 """查看多项数训练数据images与labels"""
     20 def plot_images_labels_prediction(images,labels,prediction,idx,num):#idx表示要显示的第idx个图像从idx~idx+25
     21     fig=plt.gcf()
     22     fig.set_size_inches(25,25)#设置显示尺寸
     23     if num>25:num=25
     24     for i in range(0,num):
     25         ax=plt.subplot(5,5,i+1)#一次显示多个子图
     26         ax.imshow(np.reshape(images[idx],(28,28)),cmap='binary')#将第idx个图像数据reshape成28*28的numpy并显示
     27         title="label="+str(np.argmax(labels[idx]))#设置图像的title,将onehot码转为数值码
     28         """如果有预测的prediction,则重新写title"""
     29         if len(prediction)>0:
     30             title+=",predict="+str(prediction[idx])
     31         ax.set_title(title,fontsize=10)
     32         ax.set_xticks([]);ax.set_yticks([])#设置xy轴为空,如果不设置则会有标度(像素值)
     33         idx+=1
     34     plt.show()
     35 
     36 
     37 """构造多层感知机"""
     38 """自己构造感知机"""
     39 # def layer(output_dim, input_dim, inputs, activation=None):
     40 #     W = tf.Variable(tf.random_normal([input_dim, output_dim]))
     41 #     b = tf.Variable(tf.random_normal([1, output_dim]))
     42 #     XWb = tf.matmul(inputs, W) + b
     43 #     if activation is None:
     44 #         outputs = XWb
     45 #     else:
     46 #         outputs = activation(XWb)
     47 #     return outputs
     48 
     49 """采用tf包来构造感知机"""
     50 x = tf.placeholder("float", [None, 784])
     51 h1=tf.layers.dense(inputs=x,units=256,activation=tf.nn.relu)
     52 # h1 = layer(output_dim=256, input_dim=784,
     53 #            inputs=x, activation=tf.nn.relu)
     54 y_predict = tf.layers.dense(inputs=h1,units=10,activation=None)
     55 y_label = tf.placeholder("float", [None, 10])
     56 loss_function = tf.reduce_mean(
     57     tf.nn.softmax_cross_entropy_with_logits_v2
     58     (logits=y_predict,
     59      labels=y_label))#计算损失值
     60 optimizer = tf.train.AdamOptimizer(learning_rate=0.001) 
     61     .minimize(loss_function)#使用优化器反向传播,使得损失量为最小
     62 correct_prediction = tf.equal(tf.argmax(y_label, 1),
     63                               tf.argmax(y_predict, 1))#相等为1,不想等为0,统计正确的个数
     64 accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))#精度等于正确个数除以总数
     65 """训练过程"""
     66 train_epoch=30
     67 batch_size=100
     68 loss_list=[];epoch_list=[];accuracy_list=[]
     69 starttime=time()
     70 
     71 
     72 sess=tf.Session()
     73 sess.run(tf.global_variables_initializer())
     74 for epoch in range(train_epoch):
     75     for i in range(550):
     76         batch_x, batch_y = mnist.train.next_batch(batch_size)
     77         sess.run(optimizer, feed_dict={x: batch_x, y_label: batch_y})#使用55000的训练集进行优化
     78 
     79     loss, acc = sess.run([loss_function, accuracy],
     80                          feed_dict={x: mnist.validation.images,
     81                                     y_label: mnist.validation.labels})#验证集进行验证
     82 
     83     epoch_list.append(epoch);
     84     loss_list.append(loss)
     85     accuracy_list.append(acc)
     86     print("Train Epoch:", '%02d' % (epoch + 1), "Loss=", 
     87           "{:.9f}".format(loss), " Accuracy=", acc)
     88 duration = time() - starttime
     89 print("The process has taken;{:.10f}".format(duration))
     90 fig2=plt.gcf()
     91 fig2.set_size_inches(4,2)#设置显示尺寸
     92 plt.plot(epoch_list,loss_list,label="loss")
     93 
     94 plt.ylabel('loss')
     95 plt.xlabel('epoch')
     96 plt.legend(['loss'],loc='upper left')
     97 plt.show()
     98 plt.plot(epoch_list,accuracy_list,label='acc')
     99 plt.show()
    100 # sess=tf.Session()
    101 # init = tf.global_variables_initializer()
    102 # sess.run(init)
    103 #注意这个地方,不可以重新设置sess,不可以重新开启回话,重新开启会错误
    104 print("acc:",sess.run(accuracy,feed_dict={x:mnist.test.images,y_label:mnist.test.labels}))
    105 
    106 pre_result=sess.run(tf.argmax(y_predict,1),feed_dict={x:mnist.test.images})
    107 plot_images_labels_prediction(mnist.test.images,mnist.test.labels,pre_result,0,25)
    108 sess.close()

     

     

  • 相关阅读:
    .NET 动态脚本语言
    webParts与Web部件
    比较JqGrid与XtraGrid
    XtraGrid滚轮翻页
    Python------继承
    Python 私有化类的属性
    Python print 输出不换行,只有空格
    Python--函数参数类型
    手推FP-growth (频繁模式增长)算法------挖掘频繁项集
    Python 返回多个值+Lambda的使用
  • 原文地址:https://www.cnblogs.com/hujinzhou/p/guobao_2020_2_13.html
Copyright © 2020-2023  润新知