• RNN预测字母


    #字母预测:输入a预测出b,输入b预测出c,输入c预测出d,输入d预测出e,输入e预测出a
    #10000  a
    #01000  b
    #00100  c
    #00010  d
    #00001  e
    
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Dense,SimpleRNN
    import matplotlib.pyplot as plt
    import os
    
    input_word='abcde'
    w_to_id={'a':0,'b':1,'c':2,'d':3,'e':4,}
    id_to_onehot={0:[1.,0.,0.,0.,0.],1:[0.,1.,0.,0.,0.],2:[0.,0.,1.,0.,0.],3:[0.,0.,0.,1.,0.],4:[0.,0.,0.,0.,1.]}
    x_train=[id_to_onehot[w_to_id['a']],id_to_onehot[w_to_id['b']],id_to_onehot[w_to_id['c']],id_to_onehot[w_to_id['d']],id_to_onehot[w_to_id['e']]]
    y_train=[w_to_id['b'],w_to_id['c'],w_to_id['d'],w_to_id['e'],w_to_id['a']]
    
    np.random.seed(7)
    np.random.shuffle(x_train)
    np.random.seed(7)
    np.random.shuffle(y_train)
    tf.random.set_seed(7)
    
    #使x_train符合SimpleRNN输入要求:[送入样本数,循环核时间展开步数,每个时间步输入特征个数]
    #此处整个数据集送入,所以送入研样本数为len(x_train);输入1个字母出结果,循环核时间展开步数为1;表示独热码有5个输入特征,每个时间步输入特征个数为5
    x_train=np.reshape(x_train,(len(x_train),1,5))
    y_train=np.array(y_train)
    
    model=tf.keras.Sequential([SimpleRNN(3),Dense(5,activation='softmax')])
    
    model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy'])
    
    checkpoint_save_path='./checkpoint/rnn_onehot_lprel.ckpt'
    
    if os.path.exists(checkpoint_save_path+'.index'):
        print('-------------------load the model--------------')
        model.load_weights(checkpoint_save_path)
    
    cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,
                                                   save_best_only=True,
                                                   monitor='loss')  #由于fit没有给出测试集,不计算测试集准确率,根据loss,保存最优模型
    
    history=model.fit(x_train,y_train,batch_size=32,epochs=50,callbacks=[cp_callback])
    
    model.summary()
    # print(model.trainable_variables)
    file = open('./rnn_weights.txt', 'w')
    for v in model.trainable_variables:
        file.write(str(v.name) + '
    ')
        file.write(str(v.shape) + '
    ')
        file.write(str(v.numpy()) + '
    ')
    file.close()
    
    ###############################################    show   ###############################################
    
    # 显示训练集和验证集的acc和loss曲线
    acc = history.history['sparse_categorical_accuracy']
    loss = history.history['loss']
    
    
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.title('Training Accuracy')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.show()
    
    preNum=int(input('input the number of test alphabet'))
    
    for i in range(preNum):
        alphabet1=input('input test alphabet:')
        alphabet=[id_to_onehot[w_to_id[alphabet1]]]
        #使alphabet符合SimpleRNN输入要求[送入样本数,循环核时间展开步数,每个时间步输入特征个数]
        #使此处验证效果送入了1个样本,送入样本数为1;输入1个字母出结果,所以循环核时间展开步数为1;独热码有5个输入特征,每个时间步输入特征个数为5
        alphabet=np.reshape(alphabet,(1,1,5))
        reseult=model.predict([alphabet])
        pred=tf.argmax(reseult,axis=1)
        pred=int(pred)
        tf.print(alphabet1 + '->' + input_word[pred])
  • 相关阅读:
    php 处理 json_encode 中文显示问题
    php输出cvs文件,下载cvs文件
    php服务器端生成csv文件
    在VS2013中强制IIS Express应用程序池使用经典模式
    align=absMiddle属性设置
    30个惊人的插件来扩展 Twitter Bootstrap
    jquery.fullCalendar官方文档翻译(一款小巧好用的日程管理日历, 可集成Google Calendar)
    jquery操作select(取值,设置选中)
    Bootstrap Paginator 分页 demo.
    uniform 中checkbox通过jquery 选中
  • 原文地址:https://www.cnblogs.com/python2/p/13610445.html
Copyright © 2020-2023  润新知