• TensorFlow tf.add_to_collection 用法


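    训练代码（用 tf.add_to_collection 把预测张量 pred_proba 登记到名为 'pred_proba' 的集合中；saver.save 写出的 .meta 文件会连同集合一起保存计算图）: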

    # coding: utf-8
    from __future__ import print_function
    from __future__ import division
    
    import tensorflow as tf
    import numpy as np
    import argparse
    
    
    def dense_to_one_hot(input_data, class_num):
        """Convert integer class labels into a one-hot matrix.

        Args:
            input_data: array-like of class ids in ``[0, class_num)``, shaped
                ``(n,)`` or ``(n, 1)``. Float arrays holding whole numbers (such
                as a label column read with ``np.genfromtxt``) are accepted.
            class_num: number of classes, i.e. columns of the result.

        Returns:
            float64 ``np.ndarray`` of shape ``(n, class_num)`` with exactly one
            1 per row, in the column given by that row's label.

        Raises:
            ValueError: if a label is not a whole number or lies outside
                ``[0, class_num)``.
        """
        labels = np.asarray(input_data).ravel()
        data_num = labels.shape[0]
        int_labels = labels.astype(np.int64)
        # Reject fractional / NaN labels instead of silently truncating them.
        if not np.array_equal(int_labels, labels):
            raise ValueError("labels must be whole numbers")
        # Without this check an out-of-range label would land in a neighbouring
        # row (flat indexing) or index from the end (negative values).
        if data_num and (int_labels.min() < 0 or int_labels.max() >= class_num):
            raise ValueError(
                "labels must be in [0, %d), got range [%d, %d]"
                % (class_num, int_labels.min(), int_labels.max())
            )
        labels_one_hot = np.zeros((data_num, class_num))
        labels_one_hot[np.arange(data_num), int_labels] = 1
        return labels_one_hot
    
    
    def build_parser():
        """Parse the training script's command-line options.

        Returns:
            argparse.Namespace with two required string attributes:
            ``data_path`` (the CSV file to train on) and ``model_path`` (the
            checkpoint prefix the trained model is saved under).
        """
        cli = argparse.ArgumentParser()
        for flag in ('--data_path', '--model_path'):
            cli.add_argument(flag, type=str, required=True)
        return cli.parse_args()
    
    
    p = build_parser()
    # Columns 0-1 are the features and column 2 is the class label.
    # np.genfromtxt returns a 1-D array for a single-row file, so force 2-D
    # before the column slicing below.
    origin = np.atleast_2d(np.genfromtxt(p.data_path, delimiter=','))
    
    data = origin[:, 0:2]
    labels = origin[:, 2]
    
    
    learning_rate = 0.001
    training_epochs = 5000
    display_step = 100  # print the training cost every `display_step` epochs
    
    n_features = 2
    n_class = 2
    # The placeholder is named "input" so the restored graph can be fed by
    # looking the tensor up by that name.
    x = tf.placeholder(tf.float32, [None, n_features], "input")
    y = tf.placeholder(tf.float32, [None, n_class])
    
    W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
    b = tf.Variable(tf.zeros([n_class]), name="b")
    
    scores = tf.nn.xw_plus_b(x, W, b, name='scores')
    pred_proba = tf.nn.softmax(scores, name="pred_proba")
    
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    
    saver = tf.train.Saver()
    # Register the prediction tensor in the 'pred_proba' collection. Collections
    # are exported into the .meta file, so after tf.train.import_meta_graph the
    # tensor can be retrieved with tf.get_collection('pred_proba')[0].
    tf.add_to_collection('pred_proba', pred_proba)
    init = tf.global_variables_initializer()
    
    # The training set never changes, so build the one-hot targets and the
    # feed dict once rather than on every epoch.
    one_hot_labels = dense_to_one_hot(labels.astype(int), n_class)
    train_feed = {x: data, y: one_hot_labels}
    
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(training_epochs):
            # Only the update op and the cost are needed while training.
            _, c = sess.run([optimizer, cost], feed_dict=train_feed)
            if epoch % display_step == 0:
                print(c)
        # Writes the checkpoint under p.model_path plus "<model_path>.meta",
        # which holds the graph definition and its collections.
        saver.save(sess, p.model_path)
        print("Optimization Finished!")
    

    推理代码（用 tf.train.import_meta_graph 恢复计算图并载入参数后，通过 tf.get_collection('pred_proba')[0] 取回训练时登记的预测张量）:

    # coding: utf-8
    from __future__ import print_function
    from __future__ import division
    
    import tensorflow as tf
    import numpy as np
    import argparse
    
    
    def build_parser():
        """Parse the inference script's command-line options.

        Returns:
            argparse.Namespace whose required ``model_path`` attribute is the
            checkpoint prefix to restore; the saved graph is expected at
            ``model_path + ".meta"``.
        """
        cli = argparse.ArgumentParser()
        cli.add_argument('--model_path', type=str, required=True)
        return cli.parse_args()
    
    p = build_parser()
    
    with tf.Session() as sess:
        # Rebuild the saved graph from "<model_path>.meta", then load the
        # trained variable values from the checkpoint with the same prefix.
        new_saver = tf.train.import_meta_graph(p.model_path + ".meta")
        new_saver.restore(sess, p.model_path)
        # The prediction tensor was registered in the 'pred_proba' collection,
        # so it can be fetched without knowing its op name.
        pred_proba = tf.get_collection('pred_proba')[0]
        graph = tf.get_default_graph()
        # First output of the placeholder op named "input".
        input_x = graph.get_tensor_by_name('input:0')
        sample = np.array([[0.6211, 5]])
        r = sess.run(pred_proba, feed_dict={input_x: sample})
        print(r)
        # Class 0 only when its probability is strictly larger; otherwise 1.
        if r[0][0] > r[0][1]:
            predicted = 0
        else:
            predicted = 1
        print(predicted)
    

    参考资料

    TensorFlow 模型保存/载入的两种方法

  • 相关阅读:
    bzoj 2337 [HNOI2011]XOR和路径【高斯消元+dp】
    bzoj 3196 Tyvj 1730 二逼平衡树【线段树 套 splay】
    bzoj 3528 [Zjoi2014]星系调查【树链剖分+数学】
    bzoj 2127 happiness【最小割+dinic】
    bzoj 3110 [Zjoi2013]K大数查询【树套树||整体二分】
    bzoj 4137 [FJOI2015]火星商店问题【CDQ分治+可持久化trie】
    运用背景橡皮擦抠透明郁金香
    使用快速通道抠荷花
    抠图总结
    花纹的选区
  • 原文地址:https://www.cnblogs.com/zhouyang209117/p/8424302.html
Copyright © 2020-2023  润新知