• Tensorflow Learning1 模型的保存和恢复



    CKPT->pb

    Demo

    解析

    tensor name 和 node name 的区别

    Pb 的恢复



    CKPT->pb

    tensorflow的模型保存有两种形式:

    1. ckpt:可以恢复图和变量,继续做训练

    2. pb : 将图序列化,变量成为固定的值,,只可以做inference;不能继续训练


    Demo


      1 def freeze_graph(input_checkpoint,output_graph):
      2 
      3     '''
      4     :param input_checkpoint:
      5     :param output_graph: PB模型保存路径
      6     :return
      7       void
      8     '''
      9 
     10     # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
     11     # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
     12 
     13     # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
     14     output_node_names = "InceptionV3/Logits/SpatialSqueeze" # 如果是多个输出节点,使用 ‘,’号隔开
     15 
     16     ############################     Step1: 从ckpt中恢复图:     #############################################
     17     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
     18     graph = tf.get_default_graph() # 获得默认的图, 可以省略
     19     input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图,可以省略
     20 
     21     with tf.Session() as sess: # 会使用默认的图 作为当前的图
     22         saver.restore(sess, input_checkpoint) #恢复图并得到数据
     23 
     24         ########################     Step2: 创建持久化对象,指定sess,图、以及输出的序列化节点信息    ##############
     25         output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
     26             sess=sess,
     27             input_graph_def=input_graph_def,# 等于:sess.graph_def
     28             output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
     29         #########################    Step3: 模型持久化   #######################################################
     30         with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
     31             f.write(output_graph_def.SerializeToString()) #序列化输出
     32         print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
     33         # for op in graph.get_operations():
     34 
     35         #     print(op.name, op.values())
     36 
     37 
     38 ########################### 调用方式 ################################
     39 # 输入ckpt模型路径
     40 input_checkpoint='models/model.ckpt-10000'
     41 # 输出pb模型的路径
     42 out_pb_path="models/pb/frozen_model.pb"
     43 # 调用freeze_graph将ckpt转为pb
     44 freeze_graph(input_checkpoint,out_pb_path)

    解析

    函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。

    freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

    在保存pb的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称;

    tensor name 和 node name 的区别

    node name 是 图 的节点,里面包含了很多操作和tensor

    tensor 是 node 里面的一个组成部分;

    以input 为例,“input:0”是张量的名称,而"input"表示的是节点的名称

    PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"input:0"


  • 相关阅读:
    getElement方法封装
    使用Ajax (put delete ) django原生CBV 出现csrf token解决办法
    (IO模型介绍,阻塞IO,非阻塞IO,多路复用IO,异步IO,IO模型比较分析,selectors模块,垃圾回收机制)
    协程介绍, Greenlet模块,Gevent模块,Genvent之同步与异步
    Thread类的其他方法,同步锁,死锁与递归锁,信号量,事件,条件,定时器,队列,Python标准模块--concurrent.futures
    线程概念( 线程的特点,进程与线程的关系, 线程和python理论知识,线程的创建)
    进程同步控制(锁,信号量,事件), 进程通讯(队列和管道,生产者消费者模型) 数据共享(进程池和mutiprocess.Pool模块)
    在Python程序中的进程操作,multiprocess.Process模块
    进程前戏 (操作系统简述 什么是进程)
    django ModelForm
  • 原文地址:https://www.cnblogs.com/greentomlee/p/11494383.html
Copyright © 2020-2023  润新知