• (六) Keras 模型保存和RNN简单应用


    视频学习来源

    https://www.bilibili.com/video/av40787141?from=search&seid=17003307842787199553

    笔记


    RNN用于图像识别并不是很好


    模型保存(结构和参数)

    1 需要安装h5py

    pip install h5py

    2在代码最后一行

    model.save(‘model.h5’)

    即可在当前目录保存HDF5文件

    模型载入

    1开头导入包

    from keras.models import load_model

    2导入模型

    model=load_model(‘model.h5’)

    模型载入后可接着训练

    model.fit(x_train,y_train,batch_size=64,epochs=2)

    只保存参数

    model.save_weights(‘weights.h5’)

    model.load _weights(‘weights.h5’)

    只保存网络结构

    from keras.models import model_from_json

    json_string=model.to_json()

    model=model_from_json(json_string)


    import numpy as np
    from keras.datasets import mnist  #将会从网络下载mnist数据集
    from keras.utils import np_utils
    from keras.models import Sequential  #序列模型
    from keras.layers import Dense
    from keras.layers.recurrent import SimpleRNN #keras中三种RNN  SimpleRNN,LSTM,GRU
    from keras.optimizers import Adam


    # 数据长度,一行有28个像素
    input_size=28
    # 序列长度,一共有28行
    time_steps=28
    # 隐藏层cell个数
    cell_size=50
    
    
    #载入数据
    (x_train,y_train),(x_test,y_test)=mnist.load_data()
    #查看格式
    #(60000,28,28)
    print('x_shape:',x_train.shape)
    #(60000)
    print('y_shape:',y_train.shape)
    
    
    
    #格式是(60000,28,28)
    #格式是样本数,time_steps(序列长度),input_size(每一个序列的数据长度)
    #如果数据是(60000,784)需要转成(60000,28,28)
    #除以255是做数据归一化处理
    x_train=x_train/255.0 #转换数据格式
    x_test=x_test/255.0 #转换数据格式
    #label标签转换成 one  hot 形式
    y_train=np_utils.to_categorical(y_train,num_classes=10) #分成10类
    y_test=np_utils.to_categorical(y_test,num_classes=10) #分成10类
    
    #定义序列模型
    model=Sequential()
    
    #循环神经网络
    #一个隐藏层
    model.add(SimpleRNN(
        units=cell_size,  #输出
        input_shape=(time_steps,input_size), #输入
    ))
    
    #输出层
    model.add(Dense(10,activation='softmax'))
    
    
    
    #定义优化器
    #学习速率为10的负4次方
    adam=Adam(lr=1e-4)
    
    
    #定义优化器,损失函数,训练效果中计算准确率
    model.compile(
        optimizer=adam, #sgd优化器
        loss='categorical_crossentropy',  #损失用交叉熵,速度会更快
        metrics=['accuracy'],  #计算准确率
    )
    
    #训练
    #六万张,每次训练64张,训练10个周期(六万张全部训练完算一个周期)
    model.fit(x_train,y_train,batch_size=64,epochs=10)
    
    #评估模型
    loss,accuracy=model.evaluate(x_test,y_test)
    
    print('
    test loss',loss)
    print('
    test accuracy',accuracy)
    
    loss,accuracy=model.evaluate(x_train,y_train)
    
    print('
    train loss',loss)
    print('
    train accuracy',accuracy)


    x_shape: (60000, 28, 28)
    y_shape: (60000,)
    Epoch 1/10
    60000/60000 [==============================] - 9s 145us/step - loss: 1.6191 - acc: 0.4629
    Epoch 2/10
    60000/60000 [==============================] - 9s 156us/step - loss: 0.9580 - acc: 0.7103
    Epoch 3/10
    60000/60000 [==============================] - 6s 101us/step - loss: 0.7064 - acc: 0.7934
    Epoch 4/10
    60000/60000 [==============================] - 8s 141us/step - loss: 0.5749 - acc: 0.8344
    Epoch 5/10
    60000/60000 [==============================] - 8s 128us/step - loss: 0.4999 - acc: 0.8550
    Epoch 6/10
    60000/60000 [==============================] - 6s 102us/step - loss: 0.4503 - acc: 0.8689
    Epoch 7/10
    60000/60000 [==============================] - 6s 99us/step - loss: 0.4130 - acc: 0.8808
    Epoch 8/10
    60000/60000 [==============================] - 6s 95us/step - loss: 0.3838 - acc: 0.8891
    Epoch 9/10
    60000/60000 [==============================] - 6s 96us/step - loss: 0.3597 - acc: 0.8969
    Epoch 10/10
    60000/60000 [==============================] - 6s 96us/step - loss: 0.3408 - acc: 0.9020
    10000/10000 [==============================] - 1s 73us/step
    
    test loss 0.3126664091944695
    
    test accuracy 0.91
    60000/60000 [==============================] - 4s 67us/step
    
    train loss 0.326995205249389
    
    train accuracy 0.9060166666666667
  • 相关阅读:
    创建web应用程序时出现 SharePoint HRESULT:0x80070094 问题
    用Javascript获取SharePoint当前登录用户的用户名及Group信息
    javascript连接数据库
    sharepoint 中banner 图片的放大
    GridView导出Excel 类库
    SQL Server 性能调优
    GridView长字段的显示
    MOSS母版页制作学习笔记(二)
    sharepoint 中批量导入导出
    JavaScript 动态更改sharepoint 列表的颜色
  • 原文地址:https://www.cnblogs.com/XUEYEYU/p/keras-learning-6.html
Copyright © 2020-2023  润新知