• Understanding the difference between stateful and stateless LSTMs in Keras


    This post trains an LSTM on the alphabet so that it predicts the next letter. For more detail, see:

    https://blog.csdn.net/zwqjoy/article/details/80493341

    https://machinelearningmastery.com/understanding-stateful-lstm-recurrent-neural-networks-python-keras/

    1. Predicting the next letter in stateful mode

    # Stateful LSTM to learn one-char to one-char mapping
    import numpy
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.layers import LSTM
    from keras.utils import np_utils
    # fix random seed for reproducibility
    numpy.random.seed(7)
    # define the raw dataset
    alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    # create mapping of characters to integers (0-25) and the reverse
    char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    int_to_char = dict((i, c) for i, c in enumerate(alphabet))
    # prepare the dataset of input to output pairs encoded as integers
    seq_length = 1
    dataX = []
    dataY = []
    for i in range(0, len(alphabet) - seq_length, 1):
        seq_in = alphabet[i:i + seq_length]
        seq_out = alphabet[i + seq_length]
        dataX.append([char_to_int[char] for char in seq_in])
        dataY.append(char_to_int[seq_out])
        print (seq_in, '->', seq_out)
    # reshape X to be [samples, time steps, features]
    X = numpy.reshape(dataX, (len(dataX), seq_length, 1))
    # normalize
    X = X / float(len(alphabet))
    # one hot encode the output variable
    y = np_utils.to_categorical(dataY)
    # create and fit the model
    batch_size = 1
    model = Sequential()
    model.add(LSTM(16, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))
    model.add(Dense(y.shape[1], activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    for i in range(300):
        model.fit(X, y, epochs=1, batch_size=batch_size, verbose=2, shuffle=False)
        model.reset_states()
    # summarize performance of the model
    scores = model.evaluate(X, y, batch_size=batch_size, verbose=0)
    model.reset_states()
    print("Model Accuracy: %.2f%%" % (scores[1]*100))

    OUT:

    Model Accuracy: 100.00%

    Once the model is trained, run some predictions:

    model.reset_states()  # reset the state first, so prediction starts from the beginning of the alphabet
    # demonstrate some model predictions
    seed = [char_to_int[alphabet[0]]]
    for i in range(0, len(alphabet)-1):
        x = numpy.reshape(seed, (1, len(seed), 1))
        x = x / float(len(alphabet))
        prediction = model.predict(x, verbose=0)
        index = numpy.argmax(prediction)
        print (int_to_char[seed[0]], "->", int_to_char[index])
        seed = [index]

    OUT:

    A -> B
    B -> C
    C -> D
    D -> E
    E -> F
    F -> G
    G -> H
    H -> I
    I -> J
    J -> K
    K -> L
    L -> M
    M -> N
    N -> O
    O -> P
    P -> Q
    Q -> R
    R -> S
    S -> T
    T -> U
    U -> V
    V -> W
    W -> X
    X -> Y
    Y -> Z

    What if we start predicting from a letter in the middle of the alphabet?
    model.reset_states()  # again, reset the state first
    # demonstrate a random starting point
    letter = "K"
    seed = [char_to_int[letter]]
    print ("New start: ", letter)
    for i in range(0, 5):
        x = numpy.reshape(seed, (1, len(seed), 1))
        x = x / float(len(alphabet))
        prediction = model.predict(x, verbose=0)
        index = numpy.argmax(prediction)
        print (int_to_char[seed[0]], "->", int_to_char[index])
        seed = [index]

    OUT:

    New start:  K
    K -> B
    B -> C
    C -> D
    D -> E
    E -> F
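    As the output shows, after a reset the model still answers B even though prediction starts from the middle letter K, exactly as if it had started from the beginning of the alphabet. This suggests that, in this trained model, the state carried over from the previous step (C_{t-1}) has more influence on the prediction than the current input x_t.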
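    The behaviour above can be caricatured with a toy model that ignores its input entirely and answers from an internal step counter. This is only a sketch of the "state outweighs input" reading: `PositionOnlyModel` and `predict_after_warm_up` are hypothetical names, not part of Keras, and the real LSTM is not literally a counter. The sketch also shows one plausible workaround, replaying the letters before K so the state is in the right place; whether that works as cleanly with the trained Keras model is an assumption that has not been checked here.

```python
ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


class PositionOnlyModel:
    """Hypothetical stand-in for the trained stateful LSTM.

    It ignores its input completely and answers from an internal step
    counter, the extreme case of "state outweighs input".
    """

    def __init__(self):
        self.step = 0  # plays the role of the carried-over LSTM state

    def reset_states(self):
        self.step = 0

    def predict(self, letter):
        # The input letter is deliberately unused.
        self.step = min(self.step + 1, len(ALPHABET) - 1)
        return ALPHABET[self.step]


def predict_after_warm_up(model, letter):
    """Reset, replay the letters before `letter`, then predict its successor."""
    model.reset_states()
    for prefix_letter in ALPHABET[:ALPHABET.index(letter)]:
        model.predict(prefix_letter)  # output discarded, only the state matters
    return model.predict(letter)


model = PositionOnlyModel()
model.reset_states()
cold_start = model.predict("K")                 # state says "first step"
warm_start = predict_after_warm_up(model, "K")  # state says "eleventh step"
print("cold:", cold_start, "warm:", warm_start)  # cold: B warm: L
```

    With the state freshly reset the toy answers B, as the trained model did; after replaying A through J it answers L.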
    And what if we start from a middle letter without resetting the state?
    # demonstrate a random starting point
    letter = "K"
    seed = [char_to_int[letter]]
    print ("New start: ", letter)
    for i in range(0, 5):
        x = numpy.reshape(seed, (1, len(seed), 1))
        x = x / float(len(alphabet))
        prediction = model.predict(x, verbose=0)
        index = numpy.argmax(prediction)
        print (int_to_char[seed[0]], "->", int_to_char[index])
        seed = [index]

    OUT:

    New start:  K
    K -> Z
    Z -> Z
    Z -> Z
    Z -> Z
    Z -> Z
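    As the output shows, without a reset the state simply carries on from where the previous run ended (presumably the end of the alphabet, given this output), so every prediction is Z. This again shows that the carried-over state outweighs the current input.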

    2. Predicting the next letter in stateless mode

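    As the stateful example shows, the state has to be reset by hand again and again, otherwise it keeps carrying over from the previous call, and that continuity is often not wanted. Stateless mode, which is the default, resets the state automatically, so every sample starts from a fresh state. It is also enough for most tasks, because consecutive time steps are normally packed into a single sample and the recurrence runs over them there, while the samples themselves are kept as independent as possible. So do not overuse stateful mode.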
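    The point about packing consecutive time steps into one sample can be checked on a one-unit tanh RNN written in plain Python. This is a minimal sketch, not Keras code: `run_rnn` and its weights `w_in` and `w_state` are made up for illustration. Running four time steps as one sample from a zero state gives the same final state as splitting them into two samples and carrying the state across, which is all stateful mode adds.

```python
import math


def run_rnn(inputs, state=0.0, w_in=0.5, w_state=0.9):
    """Run a one-unit tanh RNN over `inputs` and return the final state.

    The weights are arbitrary illustrative values, not taken from any
    trained model.
    """
    for x in inputs:
        state = math.tanh(w_in * x + w_state * state)
    return state


sequence = [0.1, 0.2, 0.3, 0.4]

# Stateless style: all four time steps live in one sample, which starts
# from a zero state.
one_sample = run_rnn(sequence)

# Stateful style: the sequence is split into two samples and the final
# state of the first is passed on as the initial state of the second.
carried = run_rnn(sequence[2:], state=run_rnn(sequence[:2]))

# Splitting without carrying the state loses the first two time steps.
not_carried = run_rnn(sequence[2:])

print(math.isclose(one_sample, carried))      # True
print(math.isclose(one_sample, not_carried))  # False
```

    When a whole sequence fits in one sample, stateless mode already sees all of its history; stateful mode only matters when a sequence has to be split across samples or batches.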

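    Below, a stateless model learns random subsequences of the alphabet and predicts the letter that follows. Sequences are at most 5 letters long and shorter ones are padded with zeros; the subsequences are independent of one another. This setup is probably the more common one in practice.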
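    For readers unfamiliar with the padding step, here is a plain-Python sketch of what `pad_sequences` does with its default settings (`padding='pre'`, `truncating='pre'`, value 0): zeros are added on the left, and sequences that are too long lose their earliest items. `pad_pre` is a hypothetical helper written for illustration, not the Keras implementation.

```python
def pad_pre(sequences, maxlen, value=0.0):
    """Left-pad (and left-truncate) each sequence to exactly `maxlen` items."""
    padded = []
    for seq in sequences:
        seq = list(seq)[-maxlen:]  # keep only the last `maxlen` items
        padded.append([value] * (maxlen - len(seq)) + seq)
    return padded


char_to_int = {c: i for i, c in enumerate("ABCDEFGHIJKLMNOPQRSTUVWXYZ")}
encoded = [[char_to_int[c] for c in "HIJ"], [char_to_int[c] for c in "A"]]
padded = pad_pre(encoded, maxlen=5)
print(padded)
# [[0.0, 0.0, 7, 8, 9], [0.0, 0.0, 0.0, 0.0, 0]]
```

    Note that the letter A also encodes to 0, so after padding the sequence "A" is all zeros and a leading A looks the same as padding.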

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.layers import LSTM
    from keras.utils import np_utils
    from keras.preprocessing.sequence import pad_sequences
    from keras import callbacks
    
    # fix the random seed and build the character <-> integer mapping dictionaries
    np.random.seed(7)
    alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    # create mapping of characters to integers (0-25) and the reverse
    char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    int_to_char = dict((i, c) for i, c in enumerate(alphabet))
    
    # build the training samples: 1000 sequences of at most 5 letters each
    num_inputs = 1000
    max_len = 5
    dataX = []
    dataY = []
    for i in range(num_inputs):
        start = np.random.randint(len(alphabet)-2)
        end = np.random.randint(start, min(start+max_len,len(alphabet)-1))
        sequence_in = alphabet[start:end+1]
        sequence_out = alphabet[end + 1]
        dataX.append([char_to_int[char] for char in sequence_in])
        dataY.append(char_to_int[sequence_out])
        print (sequence_in, '->', sequence_out)
    
    # convert list of lists to array and pad sequences if needed
    X = pad_sequences(dataX, maxlen=max_len, dtype='float32')
    # reshape X to be [samples, time steps, features]
    X = np.reshape(X, (X.shape[0], max_len, 1))
    # normalize
    X = X / float(len(alphabet))
    # one hot encode the output variable
    y = np_utils.to_categorical(dataY,26)
    
    # build and train the model
    batch_size = 1
    model = Sequential()
    model.add(LSTM(32, input_shape=(X.shape[1], 1),return_sequences=True))
    model.add(LSTM(32))
    model.add(Dense(y.shape[1], activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    
    # optional: RemoteMonitor posts per-epoch logs to a server at this address
    remote = callbacks.RemoteMonitor(root='http://localhost:9000')
    model.fit(X, y, epochs=200, batch_size=batch_size, verbose=2, callbacks=[remote])
    # summarize performance of the model
    scores = model.evaluate(X, y, verbose=0)
    print("Model Accuracy: %.2f%%" % (scores[1]*100))

    OUT:

    Model Accuracy: 100.00%

    Model predictions:

    # pick random training patterns and predict the next letter for each
    for i in range(20):
        pattern_index = np.random.randint(len(dataX))
        pattern = dataX[pattern_index]
        x = pad_sequences([pattern], maxlen=max_len, dtype='float32')
        x = np.reshape(x, (1, max_len, 1))
        x = x / float(len(alphabet))
        prediction = model.predict(x, verbose=0)
        index = np.argmax(prediction)
        result = int_to_char[index]
        seq_in = [int_to_char[value] for value in pattern]
        print (seq_in, "->", result)

    OUT:

    ['J'] -> K
    ['H', 'I', 'J'] -> K
    ['E', 'F'] -> G
    ['K', 'L', 'M'] -> N
    ['B'] -> C
    ['C'] -> D
    ['R', 'S'] -> T
    ['A', 'B', 'C'] -> D
    ['C', 'D', 'E'] -> F
    ['N', 'O', 'P'] -> Q
    ['C', 'D'] -> E
    ['L', 'M'] -> N
    ['F', 'G', 'H', 'I', 'J'] -> K
    ['N', 'O', 'P', 'Q'] -> R
    ['C', 'D', 'E', 'F', 'G'] -> H
    ['A', 'B', 'C'] -> D
    ['R', 'S', 'T', 'U', 'V'] -> W
    ['B', 'C', 'D'] -> E
    ['F', 'G'] -> H
    ['K'] -> L

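    As the output shows, the sampled input sequences of up to 5 letters all get the correct next letter, and there is no need to keep resetting the state by hand.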

    def predict(seq='A'):
        # keep only the uppercase letters A-Z
        seq_in = [s for s in seq if 'A' <= s <= 'Z']
        if not seq_in:
            return ''
        # encode, pad, reshape and normalize exactly as for the training data
        x = [char_to_int[s] for s in seq_in]
        x = pad_sequences([x], maxlen=max_len, dtype='float32')
        x = np.reshape(x, (1, max_len, 1))
        x = x / float(len(alphabet))
        prediction = model.predict(x, verbose=0)
        index = np.argmax(prediction)
        result = int_to_char[index]
        print(seq_in, "->", result)
        return result
    predict('OP')

    OUT:
    ['O', 'P'] -> Q

  • Original article: https://www.cnblogs.com/gczr/p/13414964.html