• TensorFlow 的 Hello World


      1 import tensorflow as tf;
      2 from tensorflow.examples.tutorials.mnist import input_data
      3 
      4 ##定义网络结构
      5 input_nodes  = 784
      6 output_nodes = 10
      7 layer1_nodes = 500
      8 #定义超参数
      9 #自动设置学习率
     10 learning_rate_base=  0.8;
     11 learning_decay = 0.99   ;
     12 decay_step=100          ;
     13 
     14 #滑动平均
     15 moving_average__decay = 0.99
     16 regularizer_rate  = 0.0001;
     17 train_step=30000
     18 batch_size= 100
     19 
     20 
     21 def inference(tensor1,weight1,bias1,weight2,bias2,average_class=None):
     22     if(average_class==None):
     23         layer1=tf.nn.relu(   tf.matmul(tensor1,weight1)+ bias1 )
     24         return tf.matmul( layer1,weight2 ) + bias2
     25     else:
     26         layer1 = tf.nn.relu(tf.matmul(tensor1, average_class.average(weight1)) + average_class.average(bias1))
     27         return tf.matmul(layer1, average_class.average(weight2) ) + average_class.average(bias2)
     28 
     29 def get_weight(shape):
     30     weight=tf.Variable(tf.truncated_normal(shape=shape,stddev=0.1),tf.float32)
     31     tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer_rate)(weight))
     32     return weight
     33 
     34 def get_bias(shape):
     35     return tf.Variable(tf.zeros(shape))
     36 
     37 def train(mnist):
     38     #定义输入输出
     39     train_x=tf.placeholder(tf.float32,shape=[None,input_nodes],name='train_x')
     40     train_y=tf.placeholder(tf.float32,shape=[None,output_nodes],name='train_y' )
     41 
     42     weight1=get_weight( [input_nodes,layer1_nodes] )
     43     bias1   =get_bias([layer1_nodes])
     44 
     45     weight2=get_weight([layer1_nodes,output_nodes]);
     46     bias2  =get_bias([output_nodes])
     47     results = inference(train_x, weight1, bias1, weight2, bias2, None)
     48 
     49     #定义学习率
     50     global_step = tf.Variable(0, trainable=False)
     51     learning_rate = tf.train.exponential_decay(learning_rate_base, global_step,  mnist.train.num_examples / batch_size, learning_decay,staircase=True)
     52 
     53     #定义损失、优化器
     54 
     55     ce= tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( logits=results,labels=tf.argmax( train_y,1) ) )
     56     loss=ce+tf.add_n( tf.get_collection('losses') )
     57     tf.summary.scalar('lost',loss)
     58 
     59     optimizer= tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step);
     60 
     61     #定义滑动平均
     62     ema = tf.train.ExponentialMovingAverage(moving_average__decay, global_step);
     63     maintain_average_op = ema.apply( tf.trainable_variables())
     64     with tf.control_dependencies([optimizer,maintain_average_op]):
     65         train_op=tf.no_op(name='train')
     66 
     67     #预测准确率
     68     average_y=inference(train_x,weight1,bias1,weight2,bias2,ema);
     69     correction_prediction = tf.equal(  tf.argmax( average_y,1 ) ,tf.argmax(train_y,1))
     70     accuracy = tf.reduce_mean(tf.cast(correction_prediction,tf.float32));
     71 
     72     with tf.Session() as sess:
     73         tf.global_variables_initializer().run()
     74 
     75         validate_feed={train_x:mnist.validation.images,train_y:mnist.validation.labels}
     76         test_feed    ={train_x:mnist.test.images,train_y:mnist.test.labels}
     77 
     78         #汇总
     79         merged_summary_op = tf.summary.merge_all()
     80         summaryWriter = tf.summary.FileWriter('./log/mnist_with_summaries',sess.graph)
     81 
     82         #迭代训练
     83         for i in range(train_step):
     84             if(i%1000 == 0 ):
     85                 validate_acc=sess.run(accuracy,feed_dict=validate_feed);
     86                 print('After %d training steps,using aaverage model is %g '%(i,validate_acc))
     87 
     88             xt,yt=mnist.train.next_batch(batch_size);
     89             sess.run( train_op,feed_dict={ train_x :xt,train_y:yt}          );
     90             summary_str=sess.run( merged_summary_op,feed_dict={ train_x :xt,train_y:yt} );
     91             summaryWriter.add_summary(summary_str,i)
     92 
     93 
     94         test_acc=sess.run(accuracy,feed_dict=test_feed)
     95         print('accuracy is %g'%(test_acc));
     96 def main():
     97     mnist= input_data.read_data_sets('./MNIST_data',one_hot=True)
     98     train(mnist);
     99 
    100 if __name__ == '__main__':
    101     main()
  • 相关阅读:
    JAVA单例MongoDB工具类
    Docker的安装使用-第1章
    JSON支持什么对象/类型?
    Linux环境源码编译安装SVN
    网站优化总结
    [java]反射1 2017-06-25 21:50 79人阅读 评论(10) 收藏
    记一次问题的解决,web自动化用例的管理
    将GatlingBundle容器化,并通过参数化来执行压测
    基于Fitnesse的接口自动化测试-关键字设计-样例-mysql操作
    基于Fitnesse的接口自动化测试-关键字设计-样例-redis操作
  • 原文地址:https://www.cnblogs.com/z-bear/p/10455547.html
Copyright © 2020-2023  润新知