• Keras 识别手写数字 MKY


    Keras 识别手写数字

    from keras.utils import np_utils
    from keras.datasets import mnist
    from keras.models import Sequential
    from keras.layers import Dense, Activation
    import numpy as np
    # Load MNIST. Naming used throughout this script: the x_* arrays are the
    # training split and the y_* arrays are the test split; *_data holds the
    # images and *_label holds the digit labels.
    (x_data, x_label), (y_data, y_label) = mnist.load_data()
    # Keep the first 50000 training samples for training and hold out the
    # remainder (index 50000 onward) as a validation set.
    x_data, x_val = x_data[:50000], x_data[50000:]
    x_label, x_val_label = x_label[:50000], x_label[50000:]
    # Shapes of every split (shown as the notebook cell's output).
    x_data.shape, x_label.shape, y_data.shape, y_label.shape, x_val.shape, x_val_label.shape

    """
     ((50000, 28, 28),(50000,),(10000, 28, 28),(10000,),(10000, 28, 28),(10000,))
    """

    # Preprocessing: flatten every 28x28 image into a 784-element row
    # (3-D array -> 2-D array), cast to float32 and divide by 255 to scale
    # the pixel values into [0, 1].
    # The original calls had lost the comma between the two reshape
    # arguments (e.g. reshape(50000784)), which requests a 1-D array of the
    # wrong size and raises ValueError. -1 lets NumPy infer the sample count.
    x_data = x_data.reshape(-1, 28 * 28).astype('float32') / 255.0
    y_data = y_data.reshape(-1, 28 * 28).astype('float32') / 255.0
    x_val = x_val.reshape(-1, 28 * 28).astype('float32') / 255.0
    # Draw a small random subset: 700 indices out of the 50000 training
    # samples and 300 indices out of the 10000 validation samples.
    # The original calls had lost the comma between the two arguments
    # (np.random.choice(50000700)), which returns a single integer instead
    # of an index array. np.random.choice samples with replacement by
    # default, so an index may be drawn more than once.
    train_rand = np.random.choice(50000, 700)
    val_rand = np.random.choice(10000, 300)
    # Rebuild the training and validation sets from the sampled indices;
    # images and labels are indexed with the same array so they stay paired.
    x_data = x_data[train_rand]
    x_label = x_label[train_rand]

    x_val = x_val[val_rand]
    x_val_label = x_val_label[val_rand]

    # Shapes after subsampling (shown as the notebook cell's output).
    # NOTE(review): the second entry repeats x_val_label.shape, so the
    # training-label shape is not displayed here.
    x_data.shape, x_val_label.shape, x_val.shape, x_val_label.shape
    """
    ((700, 784), (300,), (300, 784), (300,))
    """

    # One-hot encode the digit labels for the 10-way softmax output.
    # num_classes is given explicitly: without it to_categorical infers the
    # width from the largest label present, so a random subset that happened
    # to miss the digit 9 would yield fewer than 10 columns and no longer
    # match the Dense(10) output layer.
    x_label = np_utils.to_categorical(x_label, num_classes=10)
    x_val_label = np_utils.to_categorical(x_val_label, num_classes=10)
    y_label = np_utils.to_categorical(y_label, num_classes=10)
    # Build the network: a 784-input dense layer with 2 ReLU units followed
    # by a 10-unit softmax output layer, then print the layer summary.
    model = Sequential([
        Dense(2, input_dim=28*28, activation='relu'),
        Dense(10, activation='softmax'),
    ])
    model.summary()
    """
        Model: "sequential_1"
        _________________________________________________________________
        Layer (type)                 Output Shape              Param #   
        =================================================================
        dense_1 (Dense)              (None, 2)                 1570      
        _________________________________________________________________
        dense_2 (Dense)              (None, 10)                30        
        =================================================================
        Total params: 1,600
        Trainable params: 1,600
        Non-trainable params: 0
        _________________________________________________________________
    """

    # Compile the model: Adam optimizer, categorical cross-entropy loss
    # (the labels are one-hot encoded) and accuracy as the reported metric.
    model.compile(
        optimizer='Adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    # Train for 1000 epochs with mini-batches of 10 samples and keep the
    # returned History object so the per-epoch curves can be plotted later.
    # The original call was missing the commas after epochs=1000 and
    # batch_size=10, which is a SyntaxError.
    # NOTE(review): validation is run against the test split
    # (y_data, y_label); the held-out x_val / x_val_label set prepared
    # earlier is not used in this call - confirm whether that is intended.
    his = model.fit(x_data,
                    x_label,
                    epochs=1000,
                    batch_size=10,
                    validation_data=(y_data, y_label))
    """
        Train on 700 samples, validate on 10000 samples
        Epoch 1/1000


        2022-06-15 12:38:00.042906: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10


        700/700 [==============================] - 2s 3ms/step - loss: 2.2821 - accuracy: 0.1714 - val_loss: 2.2173 - val_accuracy: 0.2314
        Epoch 2/1000
        700/700 [==============================] - 2s 3ms/step - loss: 2.1893 - accuracy: 0.2114 - val_loss: 2.1363 - val_accuracy: 0.2008
        ..........................................................................................
        Epoch 3/1000
        Epoch 997/1000
        700/700 [==============================] - 2s 3ms/step - loss: 0.3267 - accuracy: 0.9114 - val_loss: 9.6746 - val_accuracy: 0.4252
        Epoch 998/1000
        700/700 [==============================] - 2s 3ms/step - loss: 0.3299 - accuracy: 0.9114 - val_loss: 9.9539 - val_accuracy: 0.4229
        Epoch 999/1000
        700/700 [==============================] - 2s 3ms/step - loss: 0.3312 - accuracy: 0.9071 - val_loss: 9.7498 - val_accuracy: 0.4248
        Epoch 1000/1000
        700/700 [==============================] - 2s 3ms/step - loss: 0.3347 - accuracy: 0.9014 - val_loss: 9.4837 - val_accuracy: 0.4229
    """

    %matplotlib inline
    import matplotlib.pyplot as plt
    fig, loss_ax = plt.subplots()
    acc_ax = loss_ax.twinx()


    loss_ax.plot(his.history['loss'], 'y', label='train loss')
    loss_ax.plot(his.history['val_loss'], 'r', label='val loss')

    loss_ax.plot(his.history['accuracy'], 'b', label='train acc')
    loss_ax.plot(his.history['val_accuracy'], 'g', label='val acc')



    loss_ax.set_xlabel('epoch')
    loss_ax.set_ylabel('loss')
    acc_ax.set_xlabel('accuracy')

    loss_ax.legend(loc='upper left')
    acc_ax.legend(loc='lower left')
    plt.show()
  • 相关阅读:
    poj 2251
    poj 1321
    poj 2777
    poj 3468
    poj 2318
    javascript
    buhui
    swift 构造器
    mac上不了网
    字体
  • 原文地址:https://www.cnblogs.com/menkeyi/p/16382559.html
Copyright © 2020-2023  润新知