• 如何查看tf SavedModel的输入/输出等信息?


    参考链接:https://juejin.im/post/6844903693184172040

    查看模型的Signature签名

    Tensorflow提供了一个工具

    • 如果你下载了Tensorflow的源码,可以找到这样一个文件,./tensorflow/python/tools/saved_model_cli.py
    • 如果你安装了tensorflow,也可以用下边的命令查看tensorflow源码位置和版本:
    import tensorflow as tf
    print tf.__path__
    print tf.__version__
    

    你可以加上-h参数查看saved_model_cli.py脚本的帮助信息:

    usage: saved_model_cli.py [-h] [-v] {show,run,scan} ...
    
    saved_model_cli: Command-line interface for SavedModel
    
    optional arguments:
      -h, --help       show this help message and exit
      -v, --version    show program's version number and exit
    
    commands:
      valid commands
    
      {show,run,scan}  additional help
    

    如果你安装

    指定SavedModel模所在的位置,我们就可以显示SavedModel的模型信息:

    python path/to/tensorflow/python/tools/saved_model_cli.py show --dir ./model/ --all
    

    显示类似结果

    MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
    
    signature_def['predict']:
      The given SavedModel SignatureDef contains the following input(s):
        inputs['myInput'] tensor_info:
            dtype: DT_FLOAT
            shape: (-1, 784)
            name: myInput:0
      The given SavedModel SignatureDef contains the following output(s):
        outputs['myOutput'] tensor_info:
            dtype: DT_FLOAT
            shape: (-1, 10)
            name: Softmax:0
      Method name is: tensorflow/serving/predict
    

    查看模型的计算图

    了解tensflow的人可能知道TensorBoard是一个非常强大的工具,能够显示很多模型信息,其中包括计算图。问题是,TensorBoard需要模型训练时的log,如果这个SavedModel模型是别人训练好的呢?办法也不是没有,我们可以写一段代码,加载这个模型,然后输出summary info,代码如下:

    import tensorflow as tf
    import sys
    from tensorflow.python.platform import gfile
    
    from tensorflow.core.protobuf import saved_model_pb2
    from tensorflow.python.util import compat
    
    with tf.Session() as sess:
      model_filename ='./model/saved_model.pb'
      with gfile.FastGFile(model_filename, 'rb') as f:
    
        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
    
        if 1 != len(sm.meta_graphs):
          print('More than one graph found. Not sure which to write')
          sys.exit(1)
    
        g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
    LOGDIR='./logdir'
    train_writer = tf.summary.FileWriter(LOGDIR)
    train_writer.add_graph(sess.graph)
    train_writer.flush()
    train_writer.close()
    

    代码中,将汇总信息输出到logdir,接着启动TensorBoard,加上上面的logdir:

    tensorboard --logdir ./logdir
    

    在浏览器中输入地址: http://127.0.0.1:6006/ ,就可以看到如下的计算图:

  • 相关阅读:
    Java MVC和三层架构
    EL表达式
    EL表达式中的11个隐式对象
    JDBC连接数据库7个步骤
    JSP九大内置对象和四个作用域
    Eclipse常用快捷键大全
    Java的绝对路径和相对路径
    Servlet中相对路径与绝对路径
    mysql8的深坑
    mysql单列索引和联合索引
  • 原文地址:https://www.cnblogs.com/CheeseZH/p/13524009.html
Copyright © 2020-2023  润新知