3层神经网络,自定义输入节点、隐藏层、输出节点的个数,使用sigmoid函数作为激活函数,梯度下降法进行权重的优化。
使用MNIST数据集,进行手写数字识别
1 #!/usr/bin/env python 2 # -*- coding:utf-8 -*- 3 4 #!/usr/bin/env python 5 # -*- coding:utf-8 -*- 6 7 import numpy 8 import scipy.special 9 10 11 #手写数字识别神经网络 12 class NeuralNetwork(): 13 def __init__(self,inputnodes,hiddennodes,outputnodes,learningrate): 14 ''' 15 神经网络初始化 16 :param inputnodes: 输入节点的数量 17 :param hiddennodes: 隐藏层节点的数量 18 :param outputnodes: 输出节点的数量 19 :param learningrate: 学习率 20 :return: 21 ''' 22 self.inodes = inputnodes 23 self.hnodes = hiddennodes 24 self.onodes = outputnodes 25 self.learn = learningrate 26 self.wih = numpy.random.rand(self.hnodes,self.inodes) - 0.5 27 self.who = numpy.random.rand(self.onodes,self.hnodes) - 0.5 28 # self.wih = numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.inodes,self.inodes)) 29 # self.who = numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.hnodes,self.hnodes)) 30 self.activate_function = lambda x : scipy.special.expit(x) 31 # print(self.who) 32 # print(self.wih) 33 def train(self,input_list,target_list): 34 ''' 35 训练神经网络首先计算样本输出,然后在与目标值进行对比,更新权重 36 :param input_list: 输入值 37 :param target_list: 目标值 38 :return: 39 ''' 40 #针对样本计算输出,与query函数一样 41 inputs = numpy.array(input_list).T 42 targets = numpy.array(target_list).T 43 hidden_inputs = numpy.dot(self.wih,inputs) 44 hidden_outputs = self.activate_function(hidden_inputs) 45 final_inputs = numpy.dot(self.who,hidden_outputs) 46 final_outpust = self.activate_function(final_inputs) 47 48 #将计算得到的输出与目标值对比,更新权重 49 output_error = targets - final_outpust 50 hidden_error = numpy.dot(self.who.T,output_error) 51 52 # print(output_error.shape) 53 # print(final_outpust.shape) 54 # print(hidden_outputs.T.shape) 55 # self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)),numpy.transpose(hidden_outputs)) 56 # self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs)) 57 58 self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)).reshape((self.onodes,1)),hidden_outputs.reshape((1,self.hnodes))) 59 self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)).reshape((self.hnodes,1)),inputs.reshape((1,self.inodes))) 60 61 62 63 def query(self,input_list): 64 ''' 65 计算输出 66 :param input_list: 67 :return: 68 ''' 69 inputs = numpy.array(input_list).T 70 hidden_inputs = numpy.dot(self.wih,inputs) 71 hidden_outputs = self.activate_function(hidden_inputs) 72 final_inputs = numpy.dot(self.who,hidden_outputs) 73 final_outpust = self.activate_function(final_inputs) 74 75 return final_outpust 76 77 #初始化一个神经网络对象 78 n = NeuralNetwork(784,100,10,0.5) 79 80 #训练数据 81 with open('dataset/mnist_train.csv','r') as f: 82 train_data = f.readlines() 83 84 #训练神经网络 85 for line in train_data: 86 data = line.split(',') 87 inputs = (numpy.asfarray(data[1:]) / 255 * 0.99) + 0.01 88 targets = numpy.zeros(n.onodes)+0.01 89 targets[int(data[0])] = 0.99 90 91 n.train(inputs,targets) 92 93 94 #测试神经网络 95 with open('dataset/mnist_test_10.csv','r') as f: 96 test_data = f.readlines() 97 98 for line in test_data: 99 label = int(line[0]) 100 data = line.split(',') 101 input_list = numpy.asfarray(data[1:]) 102 output = n.query(input_list) 103 104 print(label) 105 print(output)
代码实现了手写数字的识别,可以在此基础上,进行改进研究,比如调节学习率、初始化权重的方式,激活函数等变化时对结果的影响。