• TensorFlow(迭代完成逻辑回归实例)


    今天学习了一个实例:在 MNIST 数据集上用梯度下降法迭代训练 softmax 逻辑回归模型,并计算模型在训练数据和测试集上的准确度:

    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from tensorflow.examples.tutorials.mnist import input_data 
    # Load MNIST from the 'MNIST_data' directory with one-hot encoded labels.
    # NOTE(review): tensorflow.examples.tutorials.mnist is only bundled with
    # TensorFlow 1.x; under TensorFlow 2.x the import above fails -- confirm the
    # intended TF version for this script.
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    print("packs loaded")
    trainimg = mnist.train.images
    trainlabel = mnist.train.labels
    # The test arrays must come from the held-out test split, not the training split.
    testimg = mnist.test.images
    testlabel = mnist.test.labels
    print("MNIST loaded")
    # Build a softmax (multinomial logistic) regression graph in TF1 graph mode;
    # placeholders require eager execution to be disabled first.
    tf.compat.v1.disable_eager_execution()
    x = tf.compat.v1.placeholder("float", [None, 784])  # 784 input features per example
    y = tf.compat.v1.placeholder("float", [None, 10])   # one-hot labels, 10 classes
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(x, W) + b
    actv = tf.nn.softmax(logits)
    # Cross-entropy is computed from the logits instead of log(actv): taking the
    # log of softmax output gives -inf (and a NaN cost) as soon as any predicted
    # probability underflows to 0, whereas the fused op is numerically stable.
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    learning_rate = 0.01
    optm = tf.compat.v1.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    # Per-example correctness (predicted class == true class) and its mean.
    pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
    accr = tf.reduce_mean(tf.cast(pred, "float"))
    init = tf.compat.v1.global_variables_initializer()
    # Training hyperparameters.
    training_epochs = 50
    batch_size = 100
    display_step = 5  # report accuracy every `display_step` epochs
    # A single Session runs the whole graph. Opening an extra InteractiveSession
    # here and then rebinding `sess` would leak it, because it is never closed.
    sess = tf.compat.v1.Session()
    sess.run(init)  # initialise W and b to zeros
    # Mini-batch gradient descent over the training split. Accuracy is reported
    # every `display_step` epochs and after the final epoch.
    num_batch = int(mnist.train.num_examples / batch_size)  # loop-invariant
    for epoch in range(training_epochs):
        avg_cost = 0.
        for _ in range(num_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            feeds = {x: batch_xs, y: batch_ys}
            # A single run applies the update and fetches the batch cost, instead
            # of running the forward pass a second time just to read the cost.
            _, batch_cost = sess.run([optm, cost], feed_dict=feeds)
            avg_cost += batch_cost / num_batch
        # Report once per epoch. Inside the batch loop this would evaluate the
        # full test set and print num_batch times per displayed epoch.
        if epoch % display_step == 0 or epoch == training_epochs - 1:
            feeds_train = {x: batch_xs, y: batch_ys}  # last mini-batch of the epoch
            feeds_test = {x: mnist.test.images, y: mnist.test.labels}
            train_acc = sess.run(accr, feed_dict=feeds_train)
            test_acc = sess.run(accr, feed_dict=feeds_test)
            print("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f"
                  % (epoch, training_epochs, avg_cost, train_acc, test_acc))
    sess.close()  # release the session's resources once training is finished
    print("DONE")

  • 相关阅读:
    databinding在android studio2.3版本后不再默认支持使用
    html的css选择器
    补间动画Tweened Animations
    ButterKnife使用
    JDK环境变量配置
    Android学习总结——DrawerLayout 侧滑栏点击事件穿透
    Android学习总结——输入法将BottomNavigationBar(底部导航栏)顶上去的问题
    Android学习总结——INSTALL_FAILED_CONFLICTING_PROVIDER
    Android学习总结——去除标题栏
    Android学习总结——强制下线功能(广播)
  • 原文地址:https://www.cnblogs.com/yangxionghao/p/14340208.html
Copyright © 2020-2023  润新知