导入经过优化的冻结图(frozen graph)时出现异常:
# Load the serialized GraphDef bytes from disk, then parse them.
with tf.gfile.GFile(pb_file, "rb") as f:
    serialized = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized)

# Import the parsed GraphDef into a brand-new graph. For a frozen graph that
# still contains ref-typed ops (e.g. save/Assign) this call raises
# ValueError "... incompatible with expected float_ref".
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def)
异常在下面这一行抛出:
tf.import_graph_def(graph_def=graph_def)  # raises the ValueError quoted below
ValueError: Input 0 of node import/save/Assign was passed float from import/beta1_power:0 incompatible with expected float_ref.
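解决方案:首先确认 pb_file 是格式正确的冻结图(二进制 GraphDef)。错误信息表明,图中仍残留需要引用类型(float_ref)输入的节点(此处为 save/Assign),而冻结后 beta1_power 等变量已经变成了普通张量,因此导入失败。下面的代码在导入前把图中的引用类型节点(例如 RefSwitch、AssignSub)改写为对应的非引用操作。同时,在 import_graph_def() 的 name 参数中传入一个值(例如 name=''),以覆盖默认的 "import" 名称前缀,如下所示: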
import tensorflow as tf
from tensorflow.python.platform import gfile
model_path = "/tmp/frozen/dcgan.pb"

# Read and parse the serialized GraphDef. The context manager closes the file
# handle as soon as the bytes are read. gfile.GFile is the non-deprecated
# replacement for gfile.FastGFile.
with gfile.GFile(model_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
gd = graph_def  # short alias kept for any later code that refers to `gd`
# Rewrite ref-typed ops left in the frozen graph into non-ref equivalents.
# After freezing, former variables are plain (non-ref) tensors, so any op that
# still expects a `*_ref` input fails at import time with an error such as
# "Input 0 of node import/save/Assign ... incompatible with expected float_ref".
for node in graph_def.node:
    if node.op == 'RefSwitch':
        node.op = 'Switch'
        # Inputs whose names contain 'moving_' (presumably batch-norm
        # moving_mean / moving_variance) are redirected to their '/read'
        # output so Switch receives a value rather than a ref.
        for index, input_name in enumerate(node.input):
            if 'moving_' in input_name:
                node.input[index] = input_name + '/read'
    elif node.op in ('AssignSub', 'AssignAdd'):
        # In-place update -> plain arithmetic. Sub/Add have no 'use_locking'
        # attribute, so it must be removed for the node to validate.
        node.op = 'Sub' if node.op == 'AssignSub' else 'Add'
        if 'use_locking' in node.attr:
            del node.attr['use_locking']
    elif node.op == 'Assign':
        # Assign(ref, value) -> Identity(value). This is the op named in the
        # reported error (save/Assign). Drop the ref input (the first data
        # input) and the Assign-only attributes that Identity does not accept.
        node.op = 'Identity'
        for attr_name in ('use_locking', 'validate_shape'):
            if attr_name in node.attr:
                del node.attr[attr_name]
        del node.input[0]
# Import the patched GraphDef into the current default graph. name='' imports
# the nodes without the default "import/" prefix. If the graph is still
# invalid, import_graph_def raises here, so the files below are only written
# after a successful import.
tf.import_graph_def(graph_def, name='')

# Save the patched graph to the current directory, once as binary protobuf
# and once as human-readable text.
for out_name, text_mode in (('good_frozen.pb', False), ('good_frozen.pbtxt', True)):
    tf.train.write_graph(graph_def, './', out_name, as_text=text_mode)