• Keras的一些功能函数


    1、模型的信息提取

    1 # 节点信息提取
    2 config = model.get_config()  # 把model中的信息,solver.prototxt和train.prototxt信息提取出来
    3 model = Model.from_config(config)  # 还回去
    4 # or, for Sequential:
    5 model = Sequential.from_config(config) # 重构一个新的Model模型,用去其他训练,fine-tuning比较好用

    2、模型概况查询

    # 1、模型概括打印
    model.summary()
    
    # 2、返回代表模型的JSON字符串,仅包含网络结构,不包含权值。可以从JSON字符串中重构原模型:
    from models import model_from_json
    
    json_string = model.to_json()
    model = model_from_json(json_string)
    
    # 3、model.to_yaml:与model.to_json类似,同样可以从产生的YAML字符串中重构模型
    from models import model_from_yaml
    
    yaml_string = model.to_yaml()
    model = model_from_yaml(yaml_string)
    
    # 4、权重获取
    model.get_layer()      #依据层名或下标获得层对象
    model.get_weights()    #返回模型权重张量的列表,类型为numpy array
    model.set_weights()    #从numpy array里将权重载入给模型,要求数组具有与model.get_weights()相同的形状。
    
    # 查看model中Layer的信息
    model.layers 查看layer信息

    3、模型保存与加载

    model.save_weights(filepath)
    # 将模型权重保存到指定路径,文件类型是HDF5(后缀是.h5)
    
    model.load_weights(filepath, by_name=False)
    # 从HDF5文件中加载权重到当前模型中, 默认情况下模型的结构将保持不变。
    # 如果想将权重载入不同的模型(有些层相同)中,则设置by_name=True,只有名字匹配的层才会载入权重

    4、在keras中设定GPU的大小

    import tensorflow as tf
    from keras.backend.tensorflow_backend import set_session
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    set_session(tf.Session(config=config))

    5、训练与保存模型

    filepath = 'model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
    checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
    # fit model
    model.fit(x, y, epochs=20, verbose=2, callbacks=[checkpoint], validation_data=(x, y))

    6、在keras中使用tensorboard

    RUN = RUN + 1 if 'RUN' in locals() else 1   # locals() 函数会以字典类型返回当前位置的全部局部变量。
    
        LOG_DIR = model_save_path + '/training_logs/run{}'.format(RUN)
        LOG_FILE_PATH = LOG_DIR + '/checkpoint-{epoch:02d}-{val_loss:.4f}.hdf5'   # 模型Log文件以及.h5模型文件存放地址
    
        tensorboard = TensorBoard(log_dir=LOG_DIR, write_images=True)
        checkpoint = ModelCheckpoint(filepath=LOG_FILE_PATH, monitor='val_loss', verbose=1, save_best_only=True)
        early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1)
    
        history = model.fit_generator(generator=gen.generate(True), steps_per_epoch=int(gen.train_batches / 4),
                                      validation_data=gen.generate(False), validation_steps=int(gen.val_batches / 4),
                                      epochs=EPOCHS, verbose=1, callbacks=[tensorboard, checkpoint, early_stopping])
    谢谢!
  • 相关阅读:
    [微信产品经理推荐] 有车一族福音,这个小程序能够帮到你很多忙,功能很逆天!
    微信小程序开闸,关于入口、推广、场景的一些观察与思考
    微信小程序体验(2):驴妈妈景区门票即买即游
    微信小程序的机会在于重新理解群组与二维码
    如何为你的微信小程序体积瘦身?
    体验报告:微信小程序在安卓机和苹果机上的区别
    微信小程序体验(1):携程酒店机票火车票
    张小龙宣布微信小程序1月9日发布,并回答了大家最关心的8个问题
    重点必看:小程序的服务范围限制有哪些?
    一些JS常用的方法
  • 原文地址:https://www.cnblogs.com/ylxn/p/10721575.html
Copyright © 2020-2023  润新知