• TensorFlow保存和载入模型


    首先定义一个tf.train.Saver类:

    saver = tf.train.Saver(max_to_keep=1)

    其中,max_to_keep参数设定只保存最后一个参数,默认值是5,即保存最后5个模型,如果设置成0,训练过程中的所有模型都会被保存。

    模型训练好以后,保存模型:

    saver.save(sess, ckpt_dir + "/nn_model.ckpt", global_step=1)

    其中,sess是Session,ckpt_dir + "/nn_model.ckpt"是保存的路径和名称,global_step是模型名称的后缀名,由于我们只保存最后一个模型,所以可以设置为1,如果每一个模型都想保存,可以设置成训练的epoch。

    载入模型比较简单:

    saver.restore(sess, model_file)

    其中,sess是Session,model_file是模型的路径和名称。

  • 相关阅读:
    mogodb优化
    uuid
    ssl详解
    探究rh6上mysql5.6的主从、半同步、GTID多线程、SSL认证主从复制
    CMAKE MYSQL
    证书生成
    叶金荣主页
    mysqlslap
    sysbench 测试MYSQL
    mysql实验室
  • 原文地址:https://www.cnblogs.com/mstk/p/9395589.html
Copyright © 2020-2023  润新知