• 测试1


    # encoding: UTF-8
    import tensorflow as tf
    import numpy as np
    from tensorflow.examples.tutorials.mnist import input_data as mnist_data
    import tensorflow as tf
    from tensorflow.python.platform import gfile
    import os
    
    # Report the TensorFlow version string and the package's install path(s)
    # so the output shows which build produced the results below.
    print("Tensorflow version " + tf.__version__)
    print(tf.__path__)
    
    # tf.set_random_seed(0)
    
    # # 输入mnist数据
    # mnist = mnist_data.read_data_sets("data", one_hot=True)
    
    # #输入数据
    # x = tf.placeholder("float", [None, 784])
    # y_ = tf.placeholder("float", [None,10])
    
    # #权值输入
    # W = tf.Variable(tf.zeros([784,10]))
    # b = tf.Variable(tf.zeros([10]))
    # #神经网络输出
    # y = tf.nn.softmax(tf.matmul(x,W) + b)
    
    # #设置交叉熵
    # cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    
    # #设置训练模型
    # learningRate = 0.005
    # train_step = tf.train.GradientDescentOptimizer(learningRate).minimize(cross_entropy)
    
    # init = tf.initialize_all_variables()
    # sess = tf.Session()
    # sess.run(init)
    
    # itnum = 1000;
    # batch_size = 100;
    # for i in range(itnum):
    #     if i % 100 == 0:
    #         print("the index " + str(i + 1) + " train")
    #     batch_xs, batch_ys = mnist.train.next_batch(batch_size)
    #     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    
    # correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    # accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    # print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    
    
    def train():
        """Run one 3x3 conv2d over deterministic data and export the graph.

        Despite the name, nothing is optimised here. The function:

        1. Builds an integer ramp input of shape (1, inchannel, height, width)
           and a ramp kernel of shape (outchannel, inchannel, wkernel, wkernel)
           with ``np.arange`` and prints both.
        2. Transposes both arrays, wraps them in float32 ``tf.Variable``s and
           applies ``tf.nn.conv2d`` with stride 1 and ``padding='SAME'``.
        3. Evaluates the convolution, transposes the result back to the
           original axis order and prints it.
        4. Saves a checkpoint under ``log/`` (global step 200), re-imports the
           ``.meta`` file that was just written, and dumps the default graph
           to ``log/`` with a ``tf.summary.FileWriter``.

        Returns:
            None. Results are printed and written to the ``log`` directory.
        """
        height = 28
        width = 28
        inchannel = 1
        outchannel = 2

        # Layer being mimicked: kernel (3, 3), stride (1, 1), pad (1, 1).
        # With a 3x3 kernel and stride 1, padding='SAME' below supplies the
        # one-pixel border, so no explicit pad value is needed.
        wkernel = 3
        stride = 1

        # Ramp values make every output element easy to check by hand.
        w = np.arange(wkernel * wkernel * inchannel * outchannel).reshape(
            (outchannel, inchannel, wkernel, wkernel))
        data = np.arange(height * width * inchannel).reshape(
            (1, inchannel, height, width))
        print('input:', data)
        print('weight:', w)

        # Reverse the last three input axes and all four kernel axes so the
        # channel axes land where tf.nn.conv2d expects them (channels last for
        # the input, [.., .., in, out] for the kernel). Both spatial axes are
        # swapped consistently in input and kernel; the transpose applied to
        # the result below swaps them back.
        data = data.transpose(0, 3, 2, 1)
        w = w.transpose(3, 2, 1, 0)

        # Graph node names stay "input"/"weight"; only the Python locals are
        # renamed so they no longer shadow the builtins input() and filter().
        input_var = tf.Variable(data, dtype=np.float32, name="input")
        kernel_var = tf.Variable(w, dtype=np.float32, name="weight")

        conv = tf.nn.conv2d(input_var, kernel_var,
                            strides=[1, stride, stride, 1],
                            padding='SAME', name="conv")
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)

            # Back to the axis order of the arrays printed above.
            conv_reshape = sess.run(conv).transpose(0, 3, 2, 1)
            print("conv_reshape: \n", conv_reshape)

            export_dir = "log"
            step = 200
            saver = tf.train.Saver()
            checkpoint_prefix = os.path.join(export_dir, 'model.ckpt')
            # save() returns the prefix it actually wrote ("<prefix>-<step>"),
            # so the .meta path below always matches the step used here.
            saved_prefix = saver.save(sess, checkpoint_prefix, global_step=step)

            graph = tf.get_default_graph()
            meta_file = saved_prefix + '.meta'
            # NOTE(review): this imports the saved graph definition into the
            # current default graph, which presumably adds a second copy of
            # the nodes to what is written below - confirm that is intended.
            _ = tf.train.import_meta_graph(meta_file)
            summary_write = tf.summary.FileWriter(export_dir, graph)
            # Close explicitly so the event file is flushed to disk.
            summary_write.close()
    
    
    # Run the conv2d export demo only when this file is executed as a script,
    # not when it is imported as a module.
    if __name__ == '__main__':
        train()
  • 相关阅读:
    【转】【python】装饰器的原理
    Common Subsequence---最长公共子序列
    N个数的全排列 -------指定排头法
    Oil Deposits----深搜问题
    Worm
    亲和串
    n个数的最小公倍数
    整除的尾数
    Substrings 子字符串-----搜索
    N的互质数----欧拉函数
  • 原文地址:https://www.cnblogs.com/adong7639/p/9227033.html
Copyright © 2020-2023  润新知