本项目参考:
https://www.bilibili.com/video/av31500120?t=4657
训练代码:
# coding: utf-8
# Learning from Mofan and Mike G
# Recreated by Paprikatree
# Convolution NN Train
#
# Trains a small convolutional network on MNIST and saves it to
# 'model_name.h5' in the current working directory.

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Convolution2D, Activation, MaxPool2D, Flatten, Dense
from keras.optimizers import Adam
from keras.models import load_model


nb_class = 10    # number of digit classes (0-9)
nb_epoch = 4     # full passes over the training set
batchsize = 128  # samples per gradient update

# ---------------------------------------------------------------------------
# Step 1: prepare the data
# ---------------------------------------------------------------------------
# mnist.load_data() downloads the dataset on first use and caches it under
# ~/.keras/datasets (macOS and Linux).
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

# Reshape to (samples, height, width, channels).  -1 lets NumPy infer the
# number of samples; 28x28 is the image size; 1 is the single grey channel.
# TensorFlow uses channels-last, Theano uses channels-first.
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)

# Scale pixels from [0, 255] to [0, 1].  Nothing in the model defined below
# rescales its input, so this has to be done here.  The pixels are loaded as
# unsigned 8-bit integers, so an in-place `X_train /= 255.0` is not an
# option: NumPy will not store the float result back into an integer array.
# Casting to float32 first avoids building a float64 copy that is twice as
# large as needed.
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

# One-hot encode the labels, e.g. 6 --> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]
Y_train = np_utils.to_categorical(Y_train, nb_class)
Y_test = np_utils.to_categorical(Y_test, nb_class)

# ---------------------------------------------------------------------------
# Step 2: define the model
# ---------------------------------------------------------------------------
model = Sequential()

# 1st convolution layer: 32 filters of size 5x5, each producing one feature map.
model.add(Convolution2D(
    filters=32,               # the keyword is `filters`, not `filter`
    kernel_size=[5, 5],       # a list or a tuple are both accepted
    padding='same',           # pad so the 28x28 spatial size is preserved
    input_shape=(28, 28, 1),  # shape of one input sample; must be a tuple
))
model.add(Activation('relu'))
model.add(MaxPool2D(
    pool_size=(2, 2),  # keep the maximum of every 2x2 window ...
    strides=(2, 2),    # ... moving 2 pixels at a time, which halves each side
    padding="same",
))

# 2nd convolution layer: 64 filters of size 5x5.
model.add(Convolution2D(
    filters=64,
    kernel_size=(5, 5),
    padding='same',
))
model.add(Activation('relu'))
model.add(MaxPool2D(
    pool_size=(2, 2),  # same 2x2 / stride-2 pooling as above
    strides=(2, 2),
    padding="same",
))
# Open question from the author: how does the number of convolution layers
# affect the final result?

# 1st fully connected layer.
model.add(Flatten())    # flatten the feature maps into one vector per sample
model.add(Dense(1024))  # 1024 output units
model.add(Activation('relu'))

# 2nd fully connected layer.
model.add(Dense(256))   # 256 output units
model.add(Activation('tanh'))

# Output layer: one unit per class, softmax turns the scores into probabilities.
model.add(Dense(nb_class))
model.add(Activation('softmax'))

# ---------------------------------------------------------------------------
# Step 3: configure training
# ---------------------------------------------------------------------------
# NOTE(review): newer Keras releases spell this argument `learning_rate`;
# `lr` is kept here because the rest of the script targets the legacy API
# (np_utils, Convolution2D) - confirm against the installed version.
adam = Adam(lr=0.0001)

model.compile(
    optimizer=adam,  # optimizer='Adam' also works and uses the default lr=0.001
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train the network; the test split is evaluated after every epoch.
model.fit(
    x=X_train,
    y=Y_train,
    epochs=nb_epoch,
    batch_size=batchsize,
    verbose=1,  # show a progress bar
    validation_data=(X_test, Y_test),
)
model.save('model_name.h5')
# A separate model.evaluate(X_test, Y_test) call is not needed: passing
# validation_data to fit() already reports the same loss and accuracy.
测试代码:
# coding: utf-8
# Learning from Mofan and Mike G
# Recreated by Paprikatree
# Convolution NN Predict

import numpy as np
from keras.models import load_model  # restores a model written by model.save()
import matplotlib.pyplot as plt
import matplotlib.image as processimage


# Load the trained model once, at import time.  The training script saves it
# as 'model_name.h5'; the path is relative to the current working directory.
model = load_model('model_name.h5')


class MainPredictImg(object):
    """Predict the digit shown in a 28x28 image using the loaded model."""

    def __init__(self):
        pass

    def pred(self, filename):
        """Classify the image stored at ``filename``.

        Prints each class index followed by its predicted probability and
        returns the index with the highest probability.

        Args:
            filename: path of a 28x28 image readable by matplotlib.  It may
                be greyscale, RGB or RGBA; colour images are reduced to a
                single grey channel.

        Returns:
            The predicted class index.

        Raises:
            ValueError: if the image is not 28x28 pixels.
        """
        pred_img = np.asarray(processimage.imread(filename))

        # matplotlib returns PNG data as floats in [0, 1] but other formats
        # as integers; bring integer data to [0, 1] as well so it matches
        # the scaling used for training.
        if np.issubdtype(pred_img.dtype, np.integer):
            pred_img = pred_img / float(np.iinfo(pred_img.dtype).max)

        # Colour images arrive as (H, W, 3) or (H, W, 4).  Reshaping those
        # directly would produce several scrambled "images", so drop any
        # alpha channel and average the colour channels into one grey plane.
        if pred_img.ndim == 3:
            pred_img = pred_img[..., :3].mean(axis=-1)

        if pred_img.shape != (28, 28):
            raise ValueError(
                'expected a 28x28 image, got shape {}'.format(pred_img.shape)
            )

        # The model expects a batch shaped (samples, height, width, channels).
        pred_img = pred_img.reshape(1, 28, 28, 1)
        prediction = model.predict(pred_img)
        final_prediction = prediction[0].argmax()

        for digit, probability in enumerate(prediction[0]):
            print(digit)
            print('Percent:{:.30%}'.format(probability))
        return final_prediction


def main():
    """Run a prediction on the sample image '4.png' and print the result."""
    predict = MainPredictImg()
    res = predict.pred('4.png')
    print("your number is:-->", res)


if __name__ == '__main__':
    main()