训练代码:
# coding: utf-8
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
import argparse
def dense_to_one_hot(input_data, class_num):
    """Convert integer class ids into a one-hot encoded matrix.

    Args:
        input_data: integer array of class ids whose first axis is the
            sample axis and which holds exactly one label per sample,
            e.g. shape ``(N,)`` or ``(N, 1)``.
        class_num: number of classes, i.e. the width of the result.

    Returns:
        ``np.ndarray`` of shape ``(N, class_num)`` (float64) with a single 1
        per row, placed in the column given by that row's label.

    Raises:
        ValueError: if ``input_data`` does not hold exactly one label per
            sample, or any label lies outside ``[0, class_num)``.
    """
    labels = np.asarray(input_data)
    data_num = labels.shape[0]
    flat_labels = labels.ravel()
    if flat_labels.size != data_num:
        raise ValueError(
            'expected one label per sample, got %d labels for %d samples'
            % (flat_labels.size, data_num))
    # An out-of-range label would otherwise end up in the wrong row (or a
    # confusing IndexError), so reject it up front with a clear message.
    if np.any((flat_labels < 0) | (flat_labels >= class_num)):
        raise ValueError('labels must be in [0, %d)' % class_num)
    labels_one_hot = np.zeros((data_num, class_num))
    # Row i gets a 1 in column flat_labels[i].
    labels_one_hot[np.arange(data_num), flat_labels] = 1
    return labels_one_hot
def build_parser(argv=None):
    """Parse the training script's command-line arguments.

    Despite its name, this returns the parsed namespace, not the parser.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses the process command line (``sys.argv[1:]``).

    Returns:
        ``argparse.Namespace`` with string attributes ``data_path`` and
        ``model_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, required=True,
                        help='CSV file whose rows are: feature_1, '
                             'feature_2, class_label')
    parser.add_argument('--model_dir', type=str, required=True,
                        help='output directory for the exported SavedModel')
    return parser.parse_args(argv)
p = build_parser()

# Each CSV row is expected to be: feature_1, feature_2, integer class label.
# NOTE(review): np.genfromtxt turns a header row into NaNs; this assumes the
# file has no header row — confirm against the data source.
# atleast_2d keeps a single-row file as shape (1, 3) instead of (3,), so the
# column slicing below still works.
origin = np.atleast_2d(np.genfromtxt(p.data_path, delimiter=','))
data = origin[:, 0:2]
labels = origin[:, 2]

learning_rate = 0.001
training_epochs = 5000
display_step = 100  # print the training cost every `display_step` epochs
n_features = 2
n_class = 2

# Softmax (multinomial logistic) regression: scores = x @ W + b.
x = tf.placeholder(tf.float32, [None, n_features], "input")
y = tf.placeholder(tf.float32, [None, n_class])
W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
b = tf.Variable(tf.zeros([n_class]), name="b")
scores = tf.nn.xw_plus_b(x, W, b, name='scores')
pred_proba = tf.nn.softmax(scores, name="pred_proba")
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
tf.add_to_collection('pred_proba', pred_proba)
init = tf.global_variables_initializer()

# The training set never changes, so build the feed once instead of
# re-encoding the labels on each of the `training_epochs` iterations.
train_feed = {x: data, y: dense_to_one_hot(labels.astype(int), n_class)}

with tf.Session() as sess:
    sess.run(init)
    # Full-batch gradient descent: every epoch uses the whole dataset.
    for epoch in range(training_epochs):
        _, c = sess.run([optimizer, cost], feed_dict=train_feed)
        if epoch % display_step == 0:
            print(c)

    # Export as a SavedModel. The inference script looks the model up by the
    # 'test_saved_model' tag, the 'test_signature' key and the
    # 'input' / 'pred_proba' tensor keys, so these must stay in sync with it.
    # NOTE(review): TF 1.x SavedModelBuilder refuses an export directory that
    # already exists — confirm model_dir is fresh on each run.
    builder = tf.saved_model.builder.SavedModelBuilder(p.model_dir)
    inputs = {'input': tf.saved_model.utils.build_tensor_info(x)}
    outputs = {'pred_proba': tf.saved_model.utils.build_tensor_info(pred_proba)}
    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs, outputs, 'test_sig_name')
    builder.add_meta_graph_and_variables(
        sess, ['test_saved_model'], {'test_signature': signature})
    builder.save()
推理代码:
# coding: utf-8
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
import argparse
def build_parser(argv=None):
    """Parse the inference script's command-line arguments.

    Despite its name, this returns the parsed namespace, not the parser.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses the process command line (``sys.argv[1:]``).

    Returns:
        ``argparse.Namespace`` with a string attribute ``model_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, required=True,
                        help='directory containing the exported SavedModel')
    return parser.parse_args(argv)
p = build_parser()

# Load the SavedModel written by the training script and score one sample.
# The tag, signature key and tensor keys below must match the exporter.
with tf.Session() as sess:
    tags = ['test_saved_model']
    signature_key = 'test_signature'
    input_key = 'input'
    output_key = 'pred_proba'

    meta_graph_def = tf.saved_model.loader.load(sess, tags, p.model_dir)
    sig_def = meta_graph_def.signature_def[signature_key]

    # Resolve the tensor names recorded in the signature to graph tensors.
    graph = sess.graph
    x = graph.get_tensor_by_name(sig_def.inputs[input_key].name)
    y = graph.get_tensor_by_name(sig_def.outputs[output_key].name)

    # A single hard-coded sample with two features.
    sample = np.array([[0.6211, 5]])
    r = sess.run(y, feed_dict={x: sample})
    print(r)

    # Class 0 wins only when its probability is strictly larger;
    # otherwise (including ties) class 1 is reported.
    row = r[0]
    first, second = row[0], row[1]
    if first > second:
        predicted_class = 0
    else:
        predicted_class = 1
    print(predicted_class)