• 神经网络-手写字体识别


    3层神经网络,自定义输入节点、隐藏层、输出节点的个数,使用sigmoid函数作为激活函数,梯度下降法进行权重的优化。

    使用MNIST数据集,进行手写数字识别

      1 #!/usr/bin/env python
      2 # -*- coding:utf-8 -*-
      3 
      4 #!/usr/bin/env python
      5 # -*- coding:utf-8 -*-
      6 
      7 import numpy
      8 import scipy.special
      9 
     10 
     11 #手写数字识别神经网络
     12 class NeuralNetwork():
     13     def __init__(self,inputnodes,hiddennodes,outputnodes,learningrate):
     14         '''
     15         神经网络初始化
     16         :param inputnodes: 输入节点的数量
     17         :param hiddennodes: 隐藏层节点的数量
     18         :param outputnodes: 输出节点的数量
     19         :param learningrate: 学习率
     20         :return:
     21         '''
     22         self.inodes = inputnodes
     23         self.hnodes = hiddennodes
     24         self.onodes = outputnodes
     25         self.learn = learningrate
     26         self.wih = numpy.random.rand(self.hnodes,self.inodes) - 0.5
     27         self.who = numpy.random.rand(self.onodes,self.hnodes) - 0.5
     28         # self.wih = numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.inodes,self.inodes))
     29         # self.who = numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.hnodes,self.hnodes))
     30         self.activate_function = lambda x : scipy.special.expit(x)
     31         # print(self.who)
     32         # print(self.wih)
     33     def train(self,input_list,target_list):
     34         '''
     35         训练神经网络首先计算样本输出,然后在与目标值进行对比,更新权重
     36         :param input_list: 输入值
     37         :param target_list: 目标值
     38         :return:
     39         '''
     40         #针对样本计算输出,与query函数一样
     41         inputs = numpy.array(input_list).T
     42         targets = numpy.array(target_list).T
     43         hidden_inputs = numpy.dot(self.wih,inputs)
     44         hidden_outputs = self.activate_function(hidden_inputs)
     45         final_inputs = numpy.dot(self.who,hidden_outputs)
     46         final_outpust = self.activate_function(final_inputs)
     47 
     48         #将计算得到的输出与目标值对比,更新权重
     49         output_error = targets - final_outpust
     50         hidden_error = numpy.dot(self.who.T,output_error)
     51 
     52         # print(output_error.shape)
     53         # print(final_outpust.shape)
     54         # print(hidden_outputs.T.shape)
     55         # self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)),numpy.transpose(hidden_outputs))
     56         # self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs))
     57 
     58         self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)).reshape((self.onodes,1)),hidden_outputs.reshape((1,self.hnodes)))
     59         self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)).reshape((self.hnodes,1)),inputs.reshape((1,self.inodes)))
     60 
     61 
     62 
     63     def query(self,input_list):
     64         '''
     65         计算输出
     66         :param input_list:
     67         :return:
     68         '''
     69         inputs = numpy.array(input_list).T
     70         hidden_inputs = numpy.dot(self.wih,inputs)
     71         hidden_outputs = self.activate_function(hidden_inputs)
     72         final_inputs = numpy.dot(self.who,hidden_outputs)
     73         final_outpust = self.activate_function(final_inputs)
     74 
     75         return final_outpust
     76 
     77 #初始化一个神经网络对象
     78 n = NeuralNetwork(784,100,10,0.5)
     79 
     80 #训练数据
     81 with open('dataset/mnist_train.csv','r') as f:
     82     train_data = f.readlines()
     83 
     84 #训练神经网络
     85 for line in train_data:
     86     data = line.split(',')
     87     inputs = (numpy.asfarray(data[1:]) / 255 * 0.99) + 0.01
     88     targets = numpy.zeros(n.onodes)+0.01
     89     targets[int(data[0])] = 0.99
     90 
     91     n.train(inputs,targets)
     92 
     93 
     94 #测试神经网络
     95 with open('dataset/mnist_test_10.csv','r') as f:
     96     test_data = f.readlines()
     97 
     98 for line in test_data:
     99     label = int(line[0])
    100     data = line.split(',')
    101     input_list = numpy.asfarray(data[1:])
    102     output = n.query(input_list)
    103 
    104     print(label)
    105     print(output)

    代码实现了手写数字的识别,可以在此基础上,进行改进研究,比如调节学习率、初始化权重的方式,激活函数等变化时对结果的影响。

    三样东西有助于缓解生命的疲劳:希望、睡眠和微笑。---康德
  • 相关阅读:
    【转】微服务架构模式简介
    大话微服务
    Howto: 在ArcGIS10中将地图文档(mxd文档)批量保存到之前版本
    在Google Maps中导出KML文件
    ASP.NET(c#) 日期选择控件的另一种实现方法
    asp.net中的时间日期选择控件
    JAVA实现Excel导入/导出【转】
    将Gridview中的数据出到excel或word中
    asp.net导出excel并弹出保存提示框
    在ASP.NET中将GridView数据导出到Word、Excel
  • 原文地址:https://www.cnblogs.com/ronghe/p/10199972.html
Copyright © 2020-2023  润新知