• Handwritten digit recognition with a convolutional network in Keras


    import time
    
    import keras
    
    start = time.time()
    
    # MNIST: 28x28 grayscale digit images with integer labels 0-9.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    rows = 28
    cols = 28
    CLASSES = 10
    
    # Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    x_train = x_train.reshape(x_train.shape[0], rows, cols, 1)
    x_test = x_test.reshape(x_test.shape[0], rows, cols, 1)
    
    # One-hot encode the labels to match categorical_crossentropy.
    y_train = keras.utils.to_categorical(y_train, CLASSES)
    y_test = keras.utils.to_categorical(y_test, CLASSES)
    
    # Scale pixel values from [0, 255] to [0, 1].
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    x_train /= 255
    x_test /= 255
    
    # Three convolution layers (one max-pooling step after the first),
    # followed by two dropout-regularized dense layers and a softmax output.
    model = keras.models.Sequential([
        keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=x_train.shape[1:]),
        keras.layers.MaxPool2D(pool_size=(2, 2)),
        keras.layers.Conv2D(32, (3, 3), activation='relu'),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(CLASSES, activation='softmax')
    ])
    model.summary()
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(x_train, y_train, batch_size=64, epochs=5)
    
    # evaluate() returns [test loss, test accuracy].
    evaluate = model.evaluate(x_test, y_test)
    print(evaluate)
    print("elapsed: ", time.time() - start)
    
    # Save the trained model (architecture and weights) to an HDF5 file.
    model.save("mnist-con.h5")
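
The layer stack above can be checked by hand before training. The following is a minimal sketch in pure Python (no Keras needed) that traces the output shape and parameter count of each layer. It assumes the Keras defaults the script relies on: `'valid'` padding and stride 1 for `Conv2D`, and a pooling stride equal to the pool size. The helper names (`conv2d`, `max_pool`, `dense`, `trace_network`) are illustrative only and are not part of Keras. The `Dropout` layers are left out because they change neither shapes nor parameter counts.

```python
def conv2d(shape, filters, kernel=3):
    """'valid' convolution, stride 1: each spatial dim shrinks by kernel - 1."""
    h, w, c = shape
    params = filters * (kernel * kernel * c) + filters  # weights + biases
    return (h - kernel + 1, w - kernel + 1, filters), params


def max_pool(shape, pool=2):
    """Non-overlapping pooling: spatial dims are floor-divided by the pool size."""
    h, w, c = shape
    return (h // pool, w // pool, c), 0


def dense(units_in, units_out):
    """Fully connected layer: one weight per input/output pair plus biases."""
    return units_in * units_out + units_out


def trace_network(input_shape=(28, 28, 1), classes=10):
    """Return ([(layer, output shape, params), ...], total params)."""
    rows = []
    shape = input_shape
    feature_layers = [
        ("Conv2D(16)", lambda s: conv2d(s, 16)),
        ("MaxPool2D", max_pool),
        ("Conv2D(32)", lambda s: conv2d(s, 32)),
        ("Conv2D(64)", lambda s: conv2d(s, 64)),
    ]
    for name, layer in feature_layers:
        shape, params = layer(shape)
        rows.append((name, shape, params))

    # Flatten turns the (h, w, c) feature map into a single vector.
    h, w, c = shape
    units = h * w * c
    rows.append(("Flatten", (units,), 0))

    for width in (128, 128, classes):
        rows.append((f"Dense({width})", (width,), dense(units, width)))
        units = width

    total = sum(params for _, _, params in rows)
    return rows, total


rows, total = trace_network()
for name, shape, params in rows:
    print(f"{name:<12} {str(shape):<14} {params:>8}")
print("total parameters:", total)
```

Most of the parameters sit in the first dense layer, because the flattened 9 x 9 x 64 feature map has 5184 values feeding 128 units. The per-layer figures should line up with what `model.summary()` prints for the same architecture.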
    

      

  • Original article: https://www.cnblogs.com/yytxdy/p/11686457.html