• 吴裕雄--天生自然TensorFlow高层封装:使用TensorFlow-Slim处理MNIST数据集实现LeNet-5模型


    # 1. 通过TensorFlow-Slim定义卷机神经网络
    import numpy as np
    import tensorflow as tf
    import tensorflow.contrib.slim as slim
    
    from tensorflow.examples.tutorials.mnist import input_data
    
    # 通过TensorFlow-Slim来定义LeNet-5的网络结构。
    def lenet5(inputs):
        inputs = tf.reshape(inputs, [-1, 28, 28, 1])
        net = slim.conv2d(inputs, 32, [5, 5], padding='SAME', scope='layer1-conv')
        net = slim.max_pool2d(net, 2, stride=2, scope='layer2-max-pool')
        net = slim.conv2d(net, 64, [5, 5], padding='SAME', scope='layer3-conv')
        net = slim.max_pool2d(net, 2, stride=2, scope='layer4-max-pool')
        net = slim.flatten(net, scope='flatten')
        net = slim.fully_connected(net, 500, scope='layer5')
        net = slim.fully_connected(net, 10, scope='output')
        return net
    
    # 训练模型。
    def train(mnist):
        x = tf.placeholder(tf.float32, [None, 784], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
        y = lenet5(x)
    
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
        loss = tf.reduce_mean(cross_entropy)
    
        train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            for i in range(3000):
                xs, ys = mnist.train.next_batch(100)
                _, loss_value = sess.run([train_op, loss], feed_dict={x: xs, y_: ys})
                if i % 1000 == 0:
                    print("After %d training step(s), loss on training batch is %g." % (i, loss_value))
    #  主程序
    def main(argv=None):
        mnist = input_data.read_data_sets("F:\TensorFlowGoogle\201806-github\datasets\MNIST_data\", one_hot=True)
        train(mnist)
    
    if __name__ == '__main__':
        main()

  • 相关阅读:
    python for selenium 数据驱动测试
    iframe 处理
    百度登录、退出示例
    js弹框处理
    window.alert弹出处理
    通过网页内容识别和处理弹出窗口
    通过识别标题处理弹出窗口
    通过名称识别和处理弹出窗口
    发现github官网进不去了
    开发环境的搭建-记录一下
  • 原文地址:https://www.cnblogs.com/tszr/p/12069987.html
Copyright © 2020-2023  润新知