原代码：
……
# Build the classifier once at import time with ImageNet-pretrained weights;
# main() reuses this module-level instance for every prediction.
model = ResNet50(
    weights='imagenet',
)
……
def main():
    """Run the module-level ``model`` on the input ``x`` (pre-fix version).

    This is the original, unpinned variant: ``predict`` executes in whatever
    TensorFlow graph/session is current in the calling context. If that is not
    the graph/session that was active when ``model`` was built (for example,
    when main() runs on a different thread), TF 1.x Keras typically fails with
    errors such as "Tensor ... is not an element of this graph". The revised
    main() in this note re-enters the saved graph and session to avoid that.
    """
    # NOTE(review): `x` is assigned in this function, so it is a local and must
    # be bound (to the preprocessed input batch) before this line; the code that
    # builds it is not shown in this snippet.
    x = model.predict(x)
改为（加载模型时记录 session 与 graph，调用 predict 前重新进入二者，确保预测与加载模型使用同一个 graph/session）：
……
import tensorflow as tf
# NOTE(review): `tensorflow.python.keras` is TensorFlow's internal module path;
# the public equivalents are `tf.keras.backend` (TF 1.x) and
# `tf.compat.v1.keras.backend` (TF 2.x) — consider switching.
from tensorflow.python.keras.backend import set_session
# NOTE(review): `load_model` is not used anywhere in the visible code —
# presumably left over from a variant that loads a saved model file; confirm
# before removing.
from tensorflow.python.keras.models import load_model
# Create the TensorFlow session explicitly and remember it, together with the
# graph that is the default at import time, so main() can re-enter both right
# before calling predict().
# NOTE(review): ConfigProto / Session / get_default_graph are TF 1.x APIs; under
# TF 2.x they exist only as tf.compat.v1.* — confirm the TensorFlow version.
tf_config = tf.ConfigProto()  # default options; nothing is customized here
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()
# Register `sess` as the Keras backend session *before* building the model, so
# Keras uses this session when creating the model and loading its weights.
# Statement order matters here.
# NOTE(review): this sets tf.keras's backend session; if ResNet50 is imported
# from standalone `keras` (not shown), `keras.backend.set_session` would be
# needed instead — confirm which package ResNet50 comes from.
set_session(sess)
model = ResNet50(weights='imagenet')
……
def main():
    """Run the module-level ``model`` on the input ``x`` inside its own graph/session.

    ``graph`` and ``sess`` are the module-level objects captured when ``model``
    was built. Re-entering the graph and re-registering the session makes
    ``predict`` run against the same TensorFlow state the model was created in,
    even if the calling context (e.g. another thread) has a different default.
    """
    # `sess` and `graph` are only read here, never rebound, so Python already
    # resolves them as module globals; a `global` declaration is not needed.
    with graph.as_default():
        # Keras looks up the backend session from the current context, so set it
        # again inside the graph scope before predicting.
        set_session(sess)
        # NOTE(review): `x` is assigned in this function, so it is a local and
        # must be bound (to the preprocessed input batch) before this line; the
        # code that builds it is not shown in this snippet.
        x = model.predict(x)