• MNIST深度学习识别:Python全连接神经网络与PyTorch LeNet CNN网络的训练实现及比较(一)


    版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:460356155@qq.com

    全连接神经网络是深度学习的基础,理解了它就能掌握深度学习的核心概念:前向传播、误差反向传播、权重、学习率等。这里先用Python(numpy)实现模型,并以MNIST作为数据集进行训练。

    定义3层神经网络:输入层28*28=784个节点(对应MNIST图片的像素数)、隐藏层300个节点、输出层10个节点(分别对应数字0-9)。

    网络的激活函数采用sigmoid,网络权重的初始化采用正态分布。

    完整代码如下:

      1 # -*- coding:utf-8 -*-
      2 
      3 u"""全连接神经网络训练学习MINIST"""
      4 
      5 __author__ = 'zhengbiqing 460356155@qq.com'
      6 
      7 
      8 import numpy
      9 import scipy.special
     10 import scipy.misc
     11 from PIL import Image
     12 import matplotlib.pyplot
     13 import pylab
     14 import datetime
     15 from random import shuffle
     16 
     17 
     18 #是否训练网络
     19 LEARN = True
     20 
     21 #是否保存网络
     22 SAVE_PARA = False
     23 
     24 #网络节点数
     25 INPUT = 784
     26 HIDDEN = 300
     27 OUTPUT = 10
     28 
     29 #学习率和训练次数
     30 LR = 0.05
     31 EPOCH = 10
     32 
     33 #训练数据集文件
     34 TRAIN_FILE = 'mnist_train.csv'
     35 TEST_FILE = 'mnist_test.csv'
     36 
     37 #网络保存文件名
     38 WEIGHT_IH = "minist_fc_wih.npy"
     39 WEIGHT_HO = "minist_fc_who.npy"
     40 
     41 
     42 #神经网络定义
     43 class NeuralNetwork:
     44     def __init__(self, inport_nodes, hidden_nodes, output_nodes, learnning_rate):
     45         #神经网络输入层、隐藏层、输出层节点数
     46         self.inodes = inport_nodes
     47         self.hnodes = hidden_nodes
     48         self.onodes = output_nodes
     49 
     50         #神经网络训练学习率
     51         self.learnning_rate = learnning_rate
     52 
     53         #用均值为0,标准方差为连接数的-0.5次方的正态分布初始化权重
     54         #权重矩阵行列分别为hidden * input、 output * hidden,和ih、ho相反
     55         self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
     56         self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))
     57 
     58         #sigmoid函数为激活函数
     59         self.active_fun = lambda x: scipy.special.expit(x)        
     60 
     61     #设置神经网络权重,在加载已训练的权重时调用
     62     def set_weight(self, wih, who):
     63         self.wih = wih
     64         self.who = who
     65 
     66     #前向传播,根据输入得到输出
     67     def get_outputs(self, input_list):
     68         # 把list转换为N * 1的矩阵,ndmin=2二维,T转制
     69         inputs = numpy.array(input_list, ndmin=2).T
     70 
     71         # 隐藏层输入 = W dot X,矩阵乘法
     72         hidden_inputs = numpy.dot(self.wih, inputs)
     73         hidden_outputs = self.active_fun(hidden_inputs)
     74 
     75         final_inputs = numpy.dot(self.who, hidden_outputs)
     76         final_outputs = self.active_fun(final_inputs)
     77 
     78         return inputs, hidden_outputs, final_outputs
     79 
     80     #网络训练,误差计算,误差反向分配更新网络权重
     81     def train(self, input_list, target_list):
     82         inputs, hidden_outputs, final_outputs = self.get_outputs(input_list)
     83 
     84         targets = numpy.array(target_list, ndmin=2).T
     85 
     86         #误差计算
     87         output_errors = targets - final_outputs
     88         hidden_errors = numpy.dot(self.who.T, output_errors)
     89 
     90         #连接权重更新
     91         self.who += numpy.dot(self.learnning_rate * output_errors * final_outputs * (1 - final_outputs), hidden_outputs.T)
     92         self.wih += numpy.dot(self.learnning_rate * hidden_errors * hidden_outputs * (1 - hidden_outputs), inputs.T)
     93         
     94 
     95 #图像像素值变换
     96 def vals2input(vals):
     97     #[0,255]的图像像素值转换为i[0.01,1],以便sigmoid函数作非线性变换
     98     return (numpy.asfarray(vals) / 255.0 * 0.99) + 0.01
     99 
    100 
    101 '''
    102 训练网络
    103 train:是否训练网络,如果不训练则直接加载已训练得到的网络权重
    104 epoch:训练次数
    105 save:是否保存训练结果,即网络权重
    106 '''
    107 def net_train(train, epochs, save):
    108     if train:
    109         with open(TRAIN_FILE, 'r') as train_file:
    110             train_list = train_file.readlines()
    111 
    112         for epoch in range(epochs):
    113             #打乱训练数据
    114             shuffle(train_list)
    115 
    116             for data in train_list:
    117                 all_vals = data.split(',')
    118                 #图像数据为0~255,转换到0.01~1区间,以便激活函数更有效
    119                 inputs = vals2input(all_vals[1:])
    120 
    121                 #标签,正确的为0.99,其他为0.01
    122                 targets = numpy.zeros(OUTPUT) + 0.01
    123                 targets[int(all_vals[0])] = 0.99
    124 
    125                 net.train(inputs, targets)
    126 
    127             #每个epoch结束后用测试集检查识别准确度
    128             net_test(epoch)
    129             print('')
    130 
    131         if save:
    132             #保存连接权重
    133             numpy.save(WEIGHT_IH, net.wih)
    134             numpy.save(WEIGHT_HO, net.who)
    135     else:
    136         #不训练直接加载已保存的权重
    137         wih = numpy.load(WEIGHT_IH)
    138         who = numpy.load(WEIGHT_HO)
    139         net.set_weight(wih, who)
    140 
    141 
    142 '''
    143 用测试集检查准确率
    144 '''
    145 def net_test(epoch):
    146     with open(TEST_FILE, 'r') as test_file:
    147         test_list = test_file.readlines()
    148 
    149     ok = 0
    150     errlist = [0] * 10
    151 
    152     for data in test_list:
    153         all_vals = data.split(',')
    154         inputs = vals2input(all_vals[1:])
    155         _, _, net_out = net.get_outputs(inputs)
    156 
    157         max = numpy.argmax(net_out)
    158         if max == int(all_vals[0]):
    159             ok += 1
    160         else:
    161             # 识别错误统计,每个数字识别错误计数
    162             # print('target:', all_vals[0], 'net_out:', max)
    163             errlist[int(all_vals[0])] += 1
    164 
    165     print('EPOCH: {epoch} score: {score}'.format(epoch=epoch, score = ok / len(test_list) * 100))
    166     print('error list: ', errlist, ' total: ', sum(errlist))
    167 
    168 
    169 #变换图片的尺寸,保存变换后的图片
    170 def resize_img(filein, fileout, width, height, type):
    171     img = Image.open(filein)
    172     out = img.resize((width, height), Image.ANTIALIAS)
    173     out.save(fileout, type)
    174 
    175 
    176 #用训练得到的网络识别一个图片文件
    177 def img_test(img_file):
    178     file_name_list = img_file.split('.')
    179     file_name, file_type = file_name_list[0], file_name_list[1]
    180     out_file = file_name + 'out' + '.' + file_type
    181     resize_img(img_file, out_file, 28, 28, file_type)
    182 
    183     img_array = scipy.misc.imread(out_file, flatten=True)
    184     img_data = 255.0 - img_array.reshape(784)
    185     img_data = (img_data / 255.0 * 0.99) + 0.01
    186 
    187     _, _, net_out = net.get_outputs(img_data)
    188     max = numpy.argmax(net_out)
    189     print('pic recognized as: ', max)
    190 
    191 
    192 #显示数据集某个索引对应的图片
    193 def img_show(train, index):
    194     file = TRAIN_FILE if train else TEST_FILE
    195     with open(file, 'r') as test_file:
    196         test_list = test_file.readlines()
    197 
    198     all_values = test_list[index].split(',')
    199     print('number is: ', all_values[0])
    200 
    201     image_array = numpy.asfarray(all_values[1:]).reshape((28, 28))
    202     matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')
    203     pylab.show()
    204 
    205 
    206 start_time = datetime.datetime.now()
    207 
    208 net = NeuralNetwork(INPUT, HIDDEN, OUTPUT, LR)
    209 net_train(LEARN, EPOCH, SAVE_PARA)
    210 
    211 if not LEARN:
    212     net_test(0)
    213 else:
    214     print('MINIST FC Train:', INPUT, HIDDEN, OUTPUT, 'LR:', LR, 'EPOCH:', EPOCH)
    215     print('train spend time: ', datetime.datetime.now() - start_time)
    216 
    217 #用画图软件创建图片文件,由得到的网络进行识别
    218 # img_test('t9.png')
    219 
    220 #显示minist中的某个图片
    221 # img_show(True, 1)

    这个784-300-10的简单全连接神经网络,训练后在测试集上的准确率基本在97.7%左右,运行结果如下:

    EPOCH: 0 score: 95.96000000000001
    error list:  [13, 21, 31, 28, 51, 61, 33, 66, 44, 56]  total:  404

    EPOCH: 1 score: 96.77
    error list:  [15, 19, 27, 63, 37, 37, 21, 40, 18, 46]  total:  323

    EPOCH: 2 score: 97.25
    error list:  [9, 17, 26, 26, 24, 56, 21, 41, 22, 33]  total:  275

    EPOCH: 3 score: 97.82
    error list:  [9, 16, 21, 18, 20, 18, 22, 21, 31, 42]  total:  218

    EPOCH: 4 score: 97.54
    error list:  [12, 23, 17, 25, 15, 34, 19, 25, 22, 54]  total:  246

    EPOCH: 5 score: 97.78999999999999
    error list:  [10, 16, 20, 23, 21, 32, 18, 31, 26, 24]  total:  221

    EPOCH: 6 score: 97.6
    error list:  [9, 13, 26, 34, 27, 26, 20, 28, 22, 35]  total:  240

    EPOCH: 7 score: 97.74000000000001
    error list:  [12, 8, 26, 29, 27, 26, 25, 20, 27, 26]  total:  226

    EPOCH: 8 score: 97.77
    error list:  [7, 10, 27, 16, 29, 28, 23, 29, 26, 28]  total:  223

    EPOCH: 9 score: 97.99
    error list:  [11, 10, 32, 17, 18, 24, 14, 22, 21, 32]  total:  201

    MINIST FC Train: 784 300 10 LR: 0.05 EPOCH: 10
    train spend time:  0:05:54.137925

    Process finished with exit code 0

  • 相关阅读:
    jquery实现动态五角星评分
    jquery实现动态五角星评分
    三个水桶(看了三遍,想了五遍!)
    三个水桶(看了三遍,想了五遍!)
    三个水桶(看了三遍,想了五遍!)
    复制一个5G文件只需要两秒,全网最牛方法!
    复制一个5G文件只需要两秒,全网最牛方法!
    Symmetric Multiprocessor Organization
    smaller programs should improve performance RISC(精简指令集计算机)和CISC(复杂指令集计算机)是当前CPU的两种架构 区别示例
    mysqli_multi_query($link, $sql_w);
  • 原文地址:https://www.cnblogs.com/zhengbiqing/p/10407118.html
Copyright © 2020-2023  润新知