• 莫烦Python 4


    莫烦Python 4

    RNN Classifier 循环神经网络

    问题描述

    使用RNN对MNIST里面的图片进行分类

    关键

    SimpleRNN()参数

    • batch_input_shape
      使用状态RNN的注意事项

    可以将RNN设置为‘stateful’,意味着由每个batch计算出的状态都会被重用于初始化下一个batch的初始状态。状态RNN假设连续的两个batch之中,相同下标的元素有一一映射关系。

    要启用状态RNN,请在实例化层对象时指定参数stateful=True,并在Sequential模型使用固定大小的batch:通过在模型的第一层传入batch_size=(…)和input_shape来实现。在函数式模型中,对所有的输入都要指定相同的batch_size。

    如果要将循环层的状态重置,请调用.reset_states(),对模型调用将重置模型中所有状态RNN的状态。对单个层调用则只重置该层的状态。

    (samples,timesteps,input_dim)

    代码

    '''
    RNN Classifier 循环神经网络
    '''
    import numpy as np
    np.random.seed(1337)
    
    from keras.datasets import mnist
    from keras.utils import  np_utils
    from keras.models import Sequential
    from keras.layers import SimpleRNN, Activation, Dense
    from keras.optimizers import  Adam
    
    time_step = 28
    input_size = 28
    batch_size = 50
    output_size = 10
    cell_size = 50
    LR = 0.001
    
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    
    X_train = X_train.reshape(-1, 28, 28) / 255.      # normalize
    X_test = X_test.reshape(-1, 28, 28) / 255.        # normalize
    y_train = np_utils.to_categorical(y_train, num_classes=10)
    y_test = np_utils.to_categorical(y_test, num_classes=10)
    
    model = Sequential()
    model.add(
        SimpleRNN(
            batch_input_shape=(None, time_step, input_size),
            units=cell_size
        )
    )
    
    model.add(
        Dense(output_size)
    )
    
    model.add(Activation('softmax'))
    
    adam = Adam(LR)
    model.compile(
        optimizer=adam,
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    model.summary()
    
    model.fit(X_train, y_train, batch_size=batch_size, epochs=2, verbose=2, validation_data=(X_test, y_test))
    

    结果

    Model: "sequential_2"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    simple_rnn_2 (SimpleRNN)     (None, 50)                3950      
    _________________________________________________________________
    dense_2 (Dense)              (None, 10)                510       
    _________________________________________________________________
    activation_2 (Activation)    (None, 10)                0         
    =================================================================
    Total params: 4,460
    Trainable params: 4,460
    Non-trainable params: 0
    _________________________________________________________________
    Train on 60000 samples, validate on 10000 samples
    Epoch 1/2
     - 12s - loss: 0.6643 - accuracy: 0.7966 - val_loss: 0.4501 - val_accuracy: 0.8550
    Epoch 2/2
     - 9s - loss: 0.3220 - accuracy: 0.9087 - val_loss: 0.2445 - val_accuracy: 0.9359
  • 相关阅读:
    阿里云播放器弹幕选型
    使用swiper组件,轮播图在高分辨率情况下变形,图片拉高该如何解决?
    解决图片无法设置hover,以设置图片的阴影
    当标题文字超出长度后,后续用...来代替
    windows 安装wget
    【Go学习】GO中...的用法
    【Go】go test
    tcpdump工具及使用介绍
    leetcode32.最长有效括号
    Global Round 21 部分题解
  • 原文地址:https://www.cnblogs.com/Howbin/p/12599404.html
Copyright © 2020-2023  润新知