• TensorFlow CNN + RNN 基本结构


     1 #CNN
     2 x = tf.placeholder(tf.float32,[None,input_node],name="x_input")
     3 y_ = tf.placeholder(tf.float32,[None,output_node],name="y_output")
     4 
     5 #input-->layer1
     6 w_1 = tf.Variable(tf.truncted_normal([input_node,L1_node],stdev=0.5))
     7 b_1 = tf.Variable(tf.constant(0.1,shape=[L1_node]))
     8 l_conv1 = tf.nn.relu(tf.matmul(x,w_1)+b_1,strides=[1,2,2,1])
     9 l_pool1 = tf.nn.max_pool(l_conv1,strides=[1,2,2,1],ksize = [1,2,2,1],padding='SAME')
    10 
    11 #layer1-->layder2
    12 w_2 = tf.Variable(tf.truncted_normal([L1_node,L2_node],stddev=0.5))
    13 b_2 = tf.Variable(tf.constant(0.1,shape=[L2_node]))
    14 l_conv2 = tf.nn.relu(tf.matmul(l_pool1,w_2)+b_2)
    15 l_pool2 = tf.nn.max_pool(l_conv2,strides=[1,2,2,1],ksize = [1,2,2,1],padding='SAME')
    16 
    17 #layser2-->fc
    18 w_3 = tf.Variable(tf.truncted_normal([L2_node,fc_node],stddev=0.5))
    19 b_3 = tf.Variable(tf.constant(0.1,shape=[fc_node]))
    20 l_3 = tf.reshape(l_pool2,[-1,])
    21 fc_1 = tf.nn.relu(tf.matmul(l_3,w_3)+b_3)
    22 
    23 #fc-->dropout
    24 drop = tf.nn.dropout(fc_1,keep_prob)
    25 
    26 #dropout-->softmax
    27 w_4 = tf.Variable(tf.truncted_normal([fc_node,output_node],stddev=0.5))
    28 b_4 = tf.Variable(tf.constant(0.1,shape=[output_node]))
    29 y = tf.nn.softmax(tf.matmul(drop,w_4)+b_4)
    30 
    31 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=tf.argmax(y,1),labels=tf.argmax(y_,1))
    32 cross_entropy_mean = tf.reduce_mean(cross_entropy)
    33 loss = cross_entropy+reularation
    34 
    35 train_step = tf.train.GradientDescentOptimizer(leraning_rate).minimize(loss)#以何种方式何种学习率去优化何种目标
    36 
    37 correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
    38 accuracy = tf.reduce_mean(tf.cast(correct_predictiontf.float32))
    39 
    40 with tf.session() as sess:
    41     tf.global_variable_initializer().run()
    42 
    43     for i in range(max_steps):
    44         sess.run(train_step,feed_dict={x:,y:})
    45 
    46         if i%1000 == 0:
    47             validate_accu = sess.run(accuracy,feed_dict={x:x_val,y:y_val})
    48 
    49     test_accu = sess.run(accuracy,feed_dict = {x:x_test,y:y_dict})
    50 
    51 #RNN
    52 input_size = 28(28*28 image)
    53 hidden_size = 256
    54 layer_num = 2
    55 class_num =10
    56 
    57 x = tf.placeholder(tf.float32,[None,784])
    58 y = tf.placeholder(tf.float32,[None,class_num])
    59 keep_prob = tf.placeholder(tf.float32)
    60 
    61 x = tf.reshape(x,[-1,28,28])
    62 
    63 #一层lstm
    64 lstm_layer = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_size,forget_bias=1.0,..)
    65 
    66 #添加dropout
    67 lstm_layer = tf.contrib.rnn.DropoutWrrapper(lstm_layer,input_keep_prob =1.0,output_keep_prob=keep_prob)
    68 
    69 #堆叠多层
    70 mlstm = tf.contrib.rnn.MultiRNNCell([lstm_layer]*layer_sum,...)
    71 
    72 init_state = mlstm.zero_state(batch_size,dtype=tf.float32)
    73 
    74 output = mlstm(x)
    75 
    76 #添加softmax层
    77 w = tf.Variable(tf.truncted_normal([hidden_size,class_num],stddev=0.1),dtype=tf.float32)
    78 b = tf.Variable(tf.constant(0.1,shape=[class_num]),dtype=tf.float32)
    79 y_ = tf.nn.softmax(tf.matmul(output,w)+b)
    80 
    81 cross_entropy = tf.reduce_mean(-y*tf.log(y_))
    82 train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
    83 
    84 correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
    85 #tf.argmax(input, axis=None, name=None, dimension=None)此函数是对矩阵按行或列计算最大值,0:按列,此处按行
    86 accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'))#tf.cast():数据格式转换,此处bool-->float
    87 
    88 with tf.Session as sess:
    89 
    90     sess.run(train_step,feed_dict={x:,y:,keep_prob:}) #train
    91 
    92     if i%1000 ==0:
    93         train_accuracy = sess.run(accuracy,feed_dict={x:x_val,y:y_val,keep_prob:})
    94         print(train_accuracy)
    95 
    96     #测试集    
    97     test_accuracy = sess.run(accuracy,feed_dict={x:x_test,y:y_test,keep_prob:})
    98     print(test_accuracy)
  • 相关阅读:
    GridView的TemplateField
    数据源绑定
    hihocoder-1415 后缀数组三·重复旋律3 两个字符串的最长公共子串
    hihocoder-1407 后缀数组二·重复旋律2 不重合 最少重复K次
    hdu number number number 斐波那契数列 思维
    最长上升子序列 nlogn
    hdu-4507 吉哥系列故事——恨7不成妻 数位DP 状态转移分析/极限取模
    hdu-3652 B-number 数位DP
    hdu-2089 不要62 基础DP 模板
    字符串hash
  • 原文地址:https://www.cnblogs.com/buyizhiyou/p/7486510.html
Copyright © 2020-2023  润新知