• 用Keras搭建神经网络 简单模版(五)——RNN LSTM Regressor 循环神经网络


    # -*- coding: utf-8 -*-
    import numpy as np
    np.random.seed(1337)
    import matplotlib.pyplot as plt
    from keras.models import Sequential
    from keras.layers import LSTM,TimeDistributed,Dense
    from keras.optimizers import Adam
    
    BATCH_START = 0 
    TIME_STEPS = 20 
    BATCH_SIZE = 50
    INPUT_SIZE = 1 
    OUTPUT_SIZE = 1
    CELL_SIZE = 20
    LR = 0.006
    
    def get_batch():
        global BATCH_START,TIME_STEPS
        # xs shape(50,20,)
        #xs=np.arange(0,0+20*50).reshape(50,20)
        xs = np.arange(BATCH_START,BATCH_START+TIME_STEPS*BATCH_SIZE).reshape((BATCH_SIZE,TIME_STEPS)) / (10*np.pi)
        seq = np.sin(xs)
        res = np.cos(xs)
        BATCH_START += TIME_STEPS
        #plt.plot(xs[0,:],res[0,:],'r',xs[0,:],seq[0,:],'b--')
        #plt.show()
        return [seq[:,:,np.newaxis], res[:,:,np.newaxis],xs]
    
    #get_batch()
    #exit()
        
        
    model = Sequential()
    
    model.add(LSTM(output_dim=CELL_SIZE, 
                   return_sequences=True, # 每一个时间点都输出一个output
                   batch_input_shape=(BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
                   stateful = True,# batch和batch之间是否有联系
                   # 前一个batch的最后一步和后一个batch的第一步是有联系的
            )) 
    
    model.add(TimeDistributed(Dense(OUTPUT_SIZE))) # dense对每一个output连接,对每一个时间点都要计算
    
    adam = Adam(LR)
    model.compile(optimizer = adam,
                  loss = 'mse',)
    
    print('Training ------------')
    for step in range(501):
        # data shape = (batch_num,steps,inputs/output)
        X_batch, Y_batch, xs = get_batch()
        cost = model.train_on_batch(X_batch, Y_batch)
        pred = model.predict(X_batch,BATCH_SIZE)
        plt.plot(xs[0,:], Y_batch[0].flatten(),'r',xs[0,:],pred.flatten()[:TIME_STEPS],'b--')
        plt.ylim((-1.2,1.2))
        plt.draw()
        plt.pause(0.5)
        if step % 10 == 0:
            print('train cost',cost)

  • 相关阅读:
    微信小程序之阻止冒泡事件
    微信小程序之生成二维码
    微信小程序之数据缓存和数据获取
    微信小程序之分享功能
    抽丝剥茧——策略设计模式
    抽丝剥茧——单例设计模式
    抽丝剥茧——备忘录设计模式
    手把手教你Smarty缓存技术(转)
    二级域名session 共享方案(转)
    MySQL监控、性能分析——工具篇
  • 原文地址:https://www.cnblogs.com/caiyishuai/p/11311340.html
Copyright © 2020-2023  润新知