• TensorFlow 1.0：构建 LSTM 做 MNIST 图片分类


    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Load the MNIST dataset (stored under ./MNIST_data) with one-hot labels.
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Training hyper-parameters.
    lr = 0.001            # Adam learning rate
    train_iters = 10000   # total number of training examples to consume
    batch_size = 128      # examples per optimization step
    display_step = 10

    # Each image is fed to the LSTM row by row: n_steps rows of n_inputs pixels.
    n_inputs = 28         # features per time step (pixels per row)
    n_steps = 28          # sequence length (rows per image)
    n_hidden_unis = 128   # LSTM hidden units
    n_classes = 10        # output classes (digits)

    x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
    y = tf.placeholder(tf.float32, [None, n_classes])

    # Dense projections: pixel rows -> LSTM input, LSTM output -> class logits.
    # Creation order (weights in/out, then biases in/out) is kept deliberately
    # so the auto-generated graph variable names stay the same.
    weights = {
        "in": tf.Variable(tf.random_normal([n_inputs, n_hidden_unis])),    # (28, 128)
        "out": tf.Variable(tf.random_normal([n_hidden_unis, n_classes])),  # (128, 10)
    }
    biases = {
        "in": tf.Variable(tf.constant(0.1, shape=[n_hidden_unis])),  # (128,)
        "out": tf.Variable(tf.constant(0.1, shape=[n_classes])),     # (10,)
    }
    
    
    def RNN(X, weights, biases):
        """Run a single-layer LSTM over image rows and return class logits.

        Args:
            X: float32 tensor of shape (batch, n_steps, n_inputs); each image
               is treated as a sequence of n_steps rows of n_inputs pixels.
            weights: dict with "in" (n_inputs, n_hidden_unis) and
                     "out" (n_hidden_unis, n_classes) matrices.
            biases: dict with "in" (n_hidden_unis,) and "out" (n_classes,).

        Returns:
            Logits tensor of shape (batch, n_classes), computed from the LSTM
            output at the last time step.
        """
        # Reshape into the layout the LSTM expects: flatten batch and time,
        # apply the input dense layer, then restore (batch, time, features).
        X = tf.reshape(X, [-1, n_inputs])                      # (batch*28, 28)
        X_in = tf.matmul(X, weights["in"]) + biases["in"]      # (batch*28, 128)
        X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_unis])  # (batch, 28, 128)

        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(
            n_hidden_unis, forget_bias=1.0, state_is_tuple=True)
        # The LSTM state is an LSTMStateTuple (c, h). Size the zero state from
        # the runtime batch dimension rather than the global batch_size, so the
        # graph accepts any number of examples (e.g. an evaluation set).
        runtime_batch = tf.shape(X_in)[0]
        init_state = lstm_cell.zero_state(runtime_batch, dtype=tf.float32)

        outputs, states = tf.nn.dynamic_rnn(
            lstm_cell, X_in, initial_state=init_state, time_major=False)

        # outputs is (batch, n_steps, n_hidden_unis); use the last time step.
        # (states.h, the final hidden state, would be an equivalent choice.)
        last_output = outputs[:, -1, :]
        return tf.matmul(last_output, weights["out"]) + biases["out"]
    
    
    pred = RNN(x, weights, biases)
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    train_op = tf.train.AdamOptimizer(lr).minimize(loss)

    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # tf.initialize_all_variables() is deprecated; this is its replacement.
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        step = 0
        # Consume train_iters examples in mini-batches of batch_size.
        while step * batch_size < train_iters:
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Reshape each image into n_steps rows of n_inputs pixels.
            batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
            sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys})
            if step % display_step == 0:
                # Accuracy on the batch just trained on (not a held-out set).
                print(sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys}))
            # Advance the counter; without this the loop never terminates.
            step += 1
    

      

    多思考也是一种努力：做出正确的分析和选择。因为我们的时间和精力都有限，所以要把时间花在更有价值的地方。
  • 相关阅读:
    sql语句中where后边的哪些条件会使索引失效 SQL语句优化
    jvm 判断对象死亡
    mysql数据库优化概述详解
    java集合框架详解
    jvm 图形化工具之jconsole
    java io框架详解
    多台Linux之间文件共享
    二 redis的安装启动
    jvm 线上命令工具
    java 线程6种状态的转换
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12495412.html
Copyright © 2020-2023  润新知