• 『TensorFlow』第七弹_保存&载入会话_霸王回马


    首更:

    由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe...好吧,caffe也没有这种python风格的设定...

    废话少说,导入包:

    1 import numpy as np
    2 import tensorflow as tf

    保存会话:

    1 W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32)
    2 b = tf.Variable([[1,2,3]],dtype=tf.float32)
    3 
    4 init = tf.global_variables_initializer()
    5 saver = tf.train.Saver() # <---------
    6 
    7 with tf.Session() as sess:
    8     sess.run(init)
    9     save_path = saver.save(sess,'./my_net/saver_net.ckpt') # <---------

    载入会话:

    1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32)
    2 b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
    3 
    4 saver = tf.train.Saver()
    5 
    6 with tf.Session() as sess: 
    7     saver.restore(sess,'./my_net/saver_net.ckpt') # <---------
    8     print('Weight:
    ',sess.run(W))
    9     print('biases:
    ',sess.run(b))

     输出如下:

    Weight:
     [[ 1.  2.  3.]
     [ 4.  5.  6.]]
    biases:
     [[ 1.  2.  3.]]
    

     载入会话会加载之前保存的变量,所以不需要tf.global_variables_initializer()激活本次变量了。

     再更:

    引入节点名称后,只要tf变量节点的名称一致,python变量名不一致也能完美继承,也就是说tf变量节点的名称识别权限大于python变量名

    详细的命名规则下节有介绍:『TensorFlow』第八弹_变量与命名空间_固有结界

    保存模型:

    1 W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name='W') # <------
    2 b = tf.Variable([[1,2,3]],dtype=tf.float32,name='b')         # <------
    3 
    4 init = tf.global_variables_initializer()
    5 saver = tf.train.Saver()
    6 
    7 with tf.Session() as sess:
    8     sess.run(init)
    9     save_path = saver.save(sess,'./my_net/saver_net.ckpt')

    W--’W‘,b--’b‘

    载入模型:

    1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32') # <------
    2 b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32') # <------
    3 
    4 saver = tf.train.Saver()
    5 
    6 with tf.Session() as sess:
    7     saver.restore(sess,'./my_net/saver_net.ckpt')
    8     print('Weight:
    ',sess.run(W))
    9     print('biases:
    ',sess.run(b))

    W,b

    结果报错

    载入模型:

    1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name='W') # <------
    2 a = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name='b') # <------
    3 
    4 saver = tf.train.Saver()
    5 
    6 with tf.Session() as sess:
    7     saver.restore(sess,'./my_net/saver_net.ckpt')
    8     print('Weight:
    ',sess.run(W))
    9     print('biases:
    ',sess.run(a))

    W-’W‘,a--’b'

    1 INFO:tensorflow:Restoring parameters from ./my_net/saver_net.ckpt
    2 Weight:
    3  [[ 1.  2.  3.]
    4  [ 4.  5.  6.]]
    5 biases:
    6  [[ 1.  2.  3.]]
  • 相关阅读:
    eclipse 中 debug-config
    release稳定版本/snapshot快照版本
    nginx.config文件配置
    用 Spring Boot 和 MybatisPlus 快速构建项目
    github 生成ssh key
    Vagrant安装virtualbox
    修改linux默认时区
    《加密与解密》笔记
    manjaro 安装显卡驱动
    排序算法-C++实现
  • 原文地址:https://www.cnblogs.com/hellcat/p/6899683.html
Copyright © 2020-2023  润新知