• 学习进度笔记


    学习进度笔记08

    TensorFlow逻辑回归

    1. import tensorflow as tf  
    2. from tensorflow.examples.tutorials.mnist import input_data  
    3. mnist=input_data.read_data_sets("/home/yxcx/tf_data",one_hot=True)  
    4. import os  
    5. os.environ["CUDA_VISIBLE_DEVICES"]="0"  
    6. #Parameters  
    7. learning_rate=0.01  
    8. training_epochs=25  
    9. batch_size=100  
    10. display_step=1  
    11. #tf Graph Input  
    12. x=tf.placeholder(tf.float32,[None,784])  
    13. y=tf.placeholder(tf.float32,[None,10])  
    14. #Set model weights  
    15. W=tf.Variable(tf.zeros([784,10]))  
    16. b=tf.Variable(tf.zeros([10]))  
    17. #Construct model  
    18. pred=tf.nn.softmax(tf.matmul(x,W)+b)  
    19. #Minimize error using cross entropy  
    20. cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))  
    21. #Gradient Descent  
    22. optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)  
    23. #Initialize the variables  
    24. init=tf.global_variables_initializer()  
    25. #Start training  
    26. with tf.Session() as sess:  
    27. sess.run(init)  
    28. #Training cycle  
    29. for epoch in range(training_epochs):  
    30. avg_cost=0  
    31. total_batch=int(mnist.train.num_examples/batch_size)  
    32. # loop over all batches  
    33. for i in range(total_batch):  
    34. batch_xs,batch_ys=mnist.train.next_batch(batch_size)  
    35. #Fit training using batch data  
    36. _,c=sess.run([optimizer,cost],feed_dict={x:batch_xs,y:batch_ys})  
    37. #Conpute average loss  
    38. avg_cost+= c/total_batch  
    39. if (epoch+1) % display_step==0:  
    40. print("Epoch:",'%04d' % (epoch+1),"Cost:" ,"{:.09f}".format(avg_cost))  
    41. print("Optimization Finished!")  
    42. #Test model  
    43. correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))  
    44. # Calculate accuracy for 3000 examples  
    45. accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))  
    46. print("Accuracy:",accuracy.eval({x:mnist.test.images[:3000],y:mnist.test.labels[:3000]}))  
  • 相关阅读:
    缓存之雪崩现象与穿透现象
    Linux下安装php的memcached扩展(memcache的客户端)
    Linux下编译、安装php
    Linux下编译、安装并启动apache
    Linux下编译、安装并启动memcached
    memcached内存分配机制
    Memcached的过期数据的过期机制及删除机制(LRU)
    linux下mysql的root密码忘记----解决方案
    Linux服务管理
    Python中import机制
  • 原文地址:https://www.cnblogs.com/xueqiuxiang/p/14466975.html
Copyright © 2020-2023  润新知