• keras01


      1 import numpy as np
      2 from keras.datasets import mnist
      3 from keras.models import Sequential, Model
      4 from keras.layers.core import Dense, Activation, Dropout
      5 from keras.utils import np_utils
      6 
      7 import matplotlib.pyplot as plt
      8 import matplotlib.image as processimage
      9 
     10 # Load mnist RAW dataset
     11 # 训练集28*28的图片X_train = (60000, 28, 28) 训练集标签Y_train = (60000,1)
     12 # 测试集图片X_test  = (10000, 28, 28) 测试集标签Y_test  = (10000,1)
     13 (X_train, Y_train), (X_test, Y_test) = mnist.load_data()
     14 print(X_train.shape, Y_train.shape)
     15 print(X_test.shape, Y_test.shape)
     16 
     17 '''
     18 第一步,准备数据
     19 '''
     20 # Prepare 准备数据
     21 # Reshape 60k个图片,每个28*28的图片,降维成一个784的一维数组
     22 X_train = X_train.reshape(60000, 784)  # 28*28 = 784
     23 X_test = X_test.reshape(10000, 784)
     24 # set type into float32 设置成浮点型,因为使用的是GPU,GPU可以加速运算浮点型
     25 # CPU使用int型计算会更快
     26 X_train = X_train.astype('float32')  # astype SET AS TYPE INTO
     27 X_test = X_test.astype('float32')
     28 # 归一化颜色
     29 X_train = X_train/255  # 除以255个颜色,X_train(0, 255)-->(0, 1) 更有利于浮点运算
     30 X_test = X_test/255
     31 
     32 '''
     33 第二步,给神经网络设置基本参数
     34 '''
     35 # Prepare basic setups
     36 batch_sizes = 4096  # 一次给神经网络注入多少数据,别超过6万,和GPU内存有关
     37 nb_class = 10  # 设置多少个分类
     38 nb_epochs = 10  # 60k数据训练20次,一般小数据10次就够了
     39 
     40 '''
     41 第三步,设置标签
     42 '''
     43 # Class vectors label(7) into [0,0,0,0,0,0,0,1,0,1]  把7设置成向量
     44 Y_test = np_utils.to_categorical(Y_test, nb_class)  # Label
     45 Y_train = np_utils.to_categorical(Y_train, nb_class)
     46 
     47 '''
     48 第四步,设置网络结构
     49 '''
     50 model = Sequential()  # 顺序搭建层
     51 # 1st layer
     52 model.add(Dense(512, input_shape=(784,)))  # Dense是输出给下一层, input_dim = 784 [X*784]
     53 model.add(Activation('relu'))  # tanh
     54 model.add(Dropout(0.2))  # overfitting
     55 
     56 # 2nd layer
     57 model.add(Dense(256))  # 256是因为上一层已经输出512了,所以不用标注输入
     58 model.add(Activation('relu'))
     59 model.add(Dropout(0.2))
     60 
     61 # 3rd layer
     62 model.add(Dense(10))
     63 model.add(Activation('softmax'))  # 根据10层输出,softmax做分类
     64 
     65 '''
     66 第五步,编译compile
     67 '''
     68 model.compile(
     69     loss='categorical_crossentropy',
     70     optimizer='rmsprop',
     71     metrics=['accuracy']
     72 )
     73 
     74 # 启动网络训练 Fire up
     75 Trainning = model.fit(
     76     X_train, Y_train,
     77     batch_size=batch_sizes,
     78     epochs=nb_epochs,
     79     validation_data=(X_test, Y_test)
     80 )
     81 # 以上就可运行
     82 
     83 '''
     84 最后,检查工作
     85 '''
     86 # Trainning.history  # 检查训练历史
     87 # Trainning.params  # 检查训练参数
     88 
     89 
     90 # 拉取test里的图
     91 testrun = X_test[9999].reshape(1, 784)
     92 
     93 testlabel = Y_test[9999]
     94 print('label:-->', testlabel)
     95 print(testrun.shape)
     96 plt.imshow(testrun.reshape([28, 28]))
     97 
     98 # 判断输出结果
     99 pred = model.predict(testrun)
    100 print(testrun)
    101 print('label of test same Y_test[9999]-->>', testlabel)
    102 print('预测结果-->>', pred)
    103 print([final.argmax() for final in pred])  # 找到pred数组中的最大值
    104 
    105 # 用自己的画的图28*28预测一下 (不太准,可以用卷积)
    106 # 可以用PS创建28*28像素的图,且是灰度,没有色彩
    107 target_img = processimage.imread('/.../picture.jpg')
    108 print(' before reshape:->>', target_img.shape)
    109 plt.imshow(target_img)
    110 target_img = target_img.reshape(1, 784)  # reshape
    111 print(' after reshape:->>', target_img.shape)
    112 
    113 target_img = np.array(target_img)  # img --> numpy array
    114 target_img = target_img.astype('float32')  # int --> float32
    115 target_img /= 255  # (0,255) --> (0,1)
    116 
    117 print(target_img)
    118 
    119 mypred = model.predict(target_img)
    120 print(mypred)
    121 print(myfinal.argmax() for myfinal in mypred)

    参考:https://www.bilibili.com/video/av29806227

  • 相关阅读:
    OSCache报错error while trying to flush writer
    html 输入框验证
    Struts2 一张图片引发的bug
    Html 小插件10 即时新闻
    String
    内部类
    多态
    抽象&接口
    继承
    封装
  • 原文地址:https://www.cnblogs.com/paprikatree/p/10148751.html
Copyright © 2020-2023  润新知