• Tensorflow同时加载使用多个模型


    在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于Session和想使用的模型不匹配导致的错误。而使用多个graph,就需要为每个graph使用不同的Session,但是每个graph也可以在多个Session中使用,这个时候就需要在每个Session使用的时候明确申明使用的graph。

    g1 = tf.Graph() # 加载到Session 1的graph
    g2 = tf.Graph() # 加载到Session 2的graph
    
    sess1 = tf.Session(graph=g1) # Session1
    sess2 = tf.Session(graph=g2) # Session2
    
    # 加载第一个模型
    with sess1.as_default(): 
        with g1.as_default():
            tf.global_variables_initializer().run()
            model_saver = tf.train.Saver(tf.global_variables())
            model_ckpt = tf.train.get_checkpoint_state(“model1/save/path”)
            model_saver.restore(sess, model_ckpt.model_checkpoint_path)
    # 加载第二个模型
    with sess2.as_default():  # 1
        with g2.as_default():  
            tf.global_variables_initializer().run()
            model_saver = tf.train.Saver(tf.global_variables())
            model_ckpt = tf.train.get_checkpoint_state(“model2/save/path”)
            model_saver.restore(sess, model_ckpt.model_checkpoint_path)
    
    ...
    
    # 使用的时候
    with sess1.as_default():
        with sess1.graph.as_default():  # 2
            ...
    
    with sess2.as_default():
        with sess2.graph.as_default():
            ...
    
    # 关闭sess
    sess1.close()
    sess2.close()

    注:1、在1处使用as_default使session在离开的时候并不关闭,在后面可以继续使用知道手动关闭;2、由于有多个graph,所以sess.graph与tf.get_default_value的值是不相等的,因此在进入sess的时候必须sess.graph.as_default()明确申明sess.graph为当前默认graph,否则就会报错。

    PS:不同框架的模型(tf, caffe, torch等)在加载的很有可能导致底层的cuDNN分配出现问题从而报错,这种一般可以尝试通过模型的加载顺序来解决。

    TensorFlow函数:tf.Session()和tf.Session().as_default()的区别

       tf.Session().as_default():创建一个默认会话
    
       那么问题来了,会话和默认会话有什么区别呢?TensorFlow会自动生成一个默认的计算图,如果没有特殊指定,运算会自动加入这个计算图中。TensorFlow中的会话也有类似的机制,但是TensorFlow不会自动生成默认的会话,而是需要手动指定。
    
       tf.Session()创建一个会话,当上下文管理器退出时会话关闭和资源释放自动完成。
    
       tf.Session().as_default()创建一个默认会话,当上下文管理器退出时会话没有关闭,还可以通过调用会话进行run()和eval()操作,代码示例如下:
    tf.Session()代码示例:
    import tensorflow as tf
    a = tf.constant(1.0)
    b = tf.constant(2.0)
    with tf.Session() as sess:
       print(a.eval())   
    print(b.eval(session=sess))
    运行结果如下:
    1.0
    RuntimeError: Attempted to use a closed Session.

    在打印张量b的值时报错,报错为尝试使用一个已经关闭的会话。使用 tf.Session().as_default()不会有这个问题。

    对于run()方法也是一样,如果想让默认会话在退出上下文管理器时关闭会话,可以调用sess.close()方法。

    import tensorflow as tf
    a = tf.constant(1.0)
    b = tf.constant(2.0)
    with tf.Session().as_default() as sess:
       print(a.eval())  
       sess.close()
    print(b.eval(session=sess))
    1.0
    RuntimeError: Attempted to use a closed Session.

     

    参考:
    https://www.tensorflow.org/api_docs/python/tf/Session
    https://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session

    https://www.cnblogs.com/arkenstone/p/7016481.html

  • 相关阅读:
    Mysql数据库的一些操作
    【狂神说Java】JavaWeb入门到实战1---笔记
    图神经网络学习
    HWSX网址
    python如何判断两个数组完全相等?
    SQL-3-菜鸟教程
    SQL-2
    时间序列相似度分析算法
    pip install 安装不了怎么办?
    leetcode 3 无重复字符的最长子串
  • 原文地址:https://www.cnblogs.com/wynlfd/p/13930563.html
Copyright © 2020-2023  润新知