• tensorflow保存和恢复模型saver.restore


    1.本文只对一些细节点做补充,大体的步骤就不详述了
    2.保存模型
    ① 首先我使用的是tensorflow-gpu 1.4.0
    ② 这个版本生成的ckpt文件是这样的:

    其中.meta存放的是网络模型和所有的变量;
    .index 和.data一起存放变量数据
    -0 -500表示checkpoint点
    ③ 保存的配置(一定细看代码注释!!!)

    import tensorflow as tf
    w1 = tf.Variable(变量的初始化, name='w1')
    w2 = tf.Variable(变量的初始化, name='w2')
    saver = tf.train.Saver([w1,w2],max_to_keep=5, keep_checkpoint_every_n_hours=2)   # 这里是细节部分,可以指定保存的变量,每两小时保存最近的5个模型
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './checkpoint_dir/MyModel',global_step=step,write_meta_graph=False))   # 因为模型没必要多次保存,所以写为False
    

    3.恢复模型(一定细看代码注释!!!)
    代码:

    import tensorflow as tf
    with tf.Session() as sess:    
        saver = tf.train.import_meta_graph(模型路径)  # 模型路径中必须指定到具体的模型下如:xx.ckpt-500.meta,且一般来讲,所有模型都是一样的,如果没有改变模型的条件下。
        # 下面的restore就是在当前的sess下恢复了所有的变量
        saver.restore(sess,数据路径)  # 数据路径也必须指定到具体某个模型的数据,但创建这个路径的方法很多,比如调用最后一个保存的模型tf.train.latest_checkpoint('./checkpoint_dir'),也可以是xx.ckpt-500.data,并且这两个是等效的,如果是xx.ckpt-0.data,就是第一个模型的数据
        print(sess.run('w1:0'))  # 这里的w1必须加上:0
    

    ————————————————
    版权声明:本文为CSDN博主「做一只AI小能手」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    原文链接:https://blog.csdn.net/qq_37285386/article/details/88957558

  • 相关阅读:
    Maven之——仓库(下)
    《Java虚拟机原理图解》3、JVM执行时数据区
    uva 784 Maze Exploration(简单dfs)
    debian下安装mysql
    关于Go语言,自己定义结构体标签的一个妙用.
    [BI项目记]-BUG处理
    怎样加入cocostudio生成的UI到项目
    python 将有序list打乱
    Android之旅十四 android中的xml文件解析
    h5 localStorage存储大小(转)
  • 原文地址:https://www.cnblogs.com/ArdenWang/p/15350397.html
Copyright © 2020-2023  润新知