• MNIST handwritten digit recognition with TensorFlow


    This example is one that almost everyone learning TensorFlow runs into; it is a standard step on the basic learning curve. I was no exception!

    The example itself is simple. What I want to point out here is that different TensorFlow versions may expose slightly different interface functions: some of the content in the Chinese TensorFlow tutorial may no longer match the version you are using, and the version I used ran into exactly this problem. In addition, the MNIST image resources used in the example probably can no longer be downloaded from the original official site by many readers, so I also provide a download link at the end.

    My TensorFlow version information:

    >>> import tensorflow as tf
    >>> print tf.VERSION    
    1.0.1
    >>> print tf.GIT_VERSION
    v1.0.0-65-g4763edf-dirty
    >>> print tf.COMPILER_VERSION
    4.8.4
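
    Note: the exact attribute names shown above can vary between releases; tf.__version__ is the more portable attribute to check and should work on this and later TensorFlow versions as well, for example:

    >>> import tensorflow as tf
    >>> print(tf.__version__)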

    Below is what happened when I ran the code from the Chinese TensorFlow tutorial site I was following in my own environment.

     1 [root@bogon tensorflow]# python
     2 Python 2.7.5 (default, Nov  6 2016, 00:28:07) 
     3 [GCC 4.8.5 20150623 (Red Hat 4.8.5-11)] on linux2
     4 Type "help", "copyright", "credits" or "license" for more information.
     5 >>> import tensorflow.examples.tutorials.mnist.input_data as input_data
     6 >>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
     7 Traceback (most recent call last):
     8   File "<stdin>", line 1, in <module>
     9   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py", line 211, in read_data_sets
    10     SOURCE_URL + TRAIN_IMAGES)
    11   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 208, in maybe_download
    12     temp_file_name, _ = urlretrieve_with_retry(source_url)
    13   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 165, in wrapped_fn
    14     return fn(*args, **kwargs)
    15   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 190, in urlretrieve_with_retry
    16     return urllib.request.urlretrieve(url, filename)
    17   File "/usr/lib64/python2.7/urllib.py", line 94, in urlretrieve
    18     return _urlopener.retrieve(url, filename, reporthook, data)
    19   File "/usr/lib64/python2.7/urllib.py", line 240, in retrieve
    20     fp = self.open(url, data)
    21   File "/usr/lib64/python2.7/urllib.py", line 203, in open
    22     return self.open_unknown_proxy(proxy, fullurl, data)
    23   File "/usr/lib64/python2.7/urllib.py", line 222, in open_unknown_proxy
    24     raise IOError, ('url error', 'invalid proxy for %s' % type, proxy)
    25 IOError: [Errno url error] invalid proxy for http: '10.90.1.101:8080'
    26 >>> 
    27 >>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    28 Extracting MNIST_data/train-images-idx3-ubyte.gz
    29 Extracting MNIST_data/train-labels-idx1-ubyte.gz
    30 Extracting MNIST_data/t10k-images-idx3-ubyte.gz
    31 Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
    32 >>> import tensorflow as tf
    33 >>> x = tf.placeholder(tf.float32, [None, 784])
    34 >>> W = tf.Variable(tf.zeros([784,10]))
    35 >>> b = tf.Variable(tf.zeros([10]))
    36 >>> y = tf.nn.softmax(tf.matmul(x,W) + b)
    37 >>> y_ = tf.placeholder("float", [None,10])
    38 >>> cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    39 >>> train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    40 >>> init = tf.initialize_all_variables()
    41 WARNING:tensorflow:From <stdin>:1: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
    42 Instructions for updating:
    43 Use `tf.global_variables_initializer` instead.
    44 >>> init = tf.global_variables_initializer()   
    45 >>> sess = tf.Session()
    46 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
    47 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
    48 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
    49 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
    50 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
    51 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
    52 >>> sess.run(init)
    53 >>> for i in range(1000):
    54 ...   batch_xs, batch_ys = mnist.train.next_batch(100)
    55 ...   sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    56 ... 
    57 >>> correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    58 >>> accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    59 >>> print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    60 0.9088
    61 >>> 

    The log above is a complete record of my test run. It shows the following points:

    1. The error in the traceback (lines 7-25 of the log) occurred because my local machine accesses the internet through a proxy. At this step TensorFlow uses urllib to download the MNIST image resources, and because of the network/proxy problem the resource files could not be downloaded.

    2. Which resource files need to be downloaded? Following the trace into /usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py, around line 211:

    def read_data_sets(train_dir,
                       fake_data=False,
                       one_hot=False,
                       dtype=dtypes.float32,
                       reshape=True,
                       validation_size=5000):
      if fake_data:
    
        def fake():
          return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
    
        train = fake()
        validation = fake()
        test = fake()
        return base.Datasets(train=train, validation=validation, test=test)
    
      TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
      TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
      TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
      TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    
      local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                       SOURCE_URL + TRAIN_IMAGES)
      with open(local_file, 'rb') as f:
        train_images = extract_images(f)
    
      local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                       SOURCE_URL + TRAIN_LABELS)
      with open(local_file, 'rb') as f:
        train_labels = extract_labels(f, one_hot=one_hot)
    
      local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                       SOURCE_URL + TEST_IMAGES)
      with open(local_file, 'rb') as f:
        test_images = extract_images(f)
    
      local_file = base.maybe_download(TEST_LABELS, train_dir,
                                       SOURCE_URL + TEST_LABELS)
      with open(local_file, 'rb') as f:
        test_labels = extract_labels(f, one_hot=one_hot)
    
      if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'
            .format(len(train_images), validation_size))
    
      validation_images = train_images[:validation_size]
      validation_labels = train_labels[:validation_size]
      train_images = train_images[validation_size:]
      train_labels = train_labels[validation_size:]
    
      train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
      validation = DataSet(validation_images,
                           validation_labels,
                           dtype=dtype,
                           reshape=reshape)
      test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)
    
      return base.Datasets(train=train, validation=validation, test=test)

    The four file names in the listing above (TRAIN_IMAGES, TRAIN_LABELS, TEST_IMAGES, TEST_LABELS) are exactly the image resource files that need to be downloaded. My network environment could not download them directly, so I obtained them through another channel. Although the direct download failed, the program had still created the MNIST_data directory under the path where I started Python, so I simply placed the downloaded train-images-idx3-ubyte.gz, train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz files into that MNIST_data directory.

    Then, running mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) again succeeds without errors and produces the output shown on lines 28-31. (A sketch of doing the download by hand follows below.)
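
    If your environment also cannot reach the original download site, the following is a minimal sketch of fetching the four archives yourself before calling read_data_sets. The mirror address is an assumption; substitute any server you can actually reach that hosts the MNIST files. It uses the Python 2 urllib, matching the environment above.

    # Minimal sketch: fetch the four MNIST archives into MNIST_data/ by hand.
    # MIRROR_URL is an assumption; replace it with a mirror you can reach.
    import os
    import urllib  # Python 2; on Python 3 use urllib.request instead

    MIRROR_URL = 'http://yann.lecun.com/exdb/mnist/'  # placeholder mirror address
    FILES = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
    TARGET_DIR = 'MNIST_data'

    # Create the target directory if the failed run did not already create it.
    if not os.path.exists(TARGET_DIR):
        os.makedirs(TARGET_DIR)

    # Download each archive only if it is not already present locally.
    for name in FILES:
        dest = os.path.join(TARGET_DIR, name)
        if not os.path.exists(dest):
            print 'downloading', name
            urllib.urlretrieve(MIRROR_URL + name, dest)

    Once the four .gz files sit under MNIST_data/, read_data_sets("MNIST_data/", one_hot=True) finds them locally and skips the download step.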

    3. When the code on line 40 runs, a WARNING appears suggesting a replacement function. Following the hint, I ran the code on line 44 instead and everything was fine. This shows that API compatibility across versions is something to watch out for with TensorFlow.

    4. After the evaluation runs, the result shown on line 60 gives a recognition accuracy of 0.9088. (The whole session is consolidated into a single script below.)
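
    For convenience, here is the interactive session above consolidated into one script. It is the same code, with the deprecated initializer already replaced by tf.global_variables_initializer(); it assumes the four MNIST archives are already under MNIST_data/ and, like the session, targets Python 2 with the TensorFlow 1.0 API.

    import tensorflow as tf
    import tensorflow.examples.tutorials.mnist.input_data as input_data

    # Load the MNIST data (expects the four .gz files under MNIST_data/).
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    # Softmax regression model: y = softmax(x * W + b).
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)

    # Cross-entropy loss against the one-hot labels.
    y_ = tf.placeholder("float", [None, 10])
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    # Train with mini-batches of 100 samples for 1000 steps.
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Evaluate accuracy on the test set.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

    Running this should reproduce an accuracy of roughly 0.9, as on line 60 of the log.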

    The theory behind the recognition performance of this MNIST example is not the focus of this post; readers can study the MNIST-related literature on their own.

    Finally, attached is the download address for the resource files used in this MNIST example; click to download. (Note: points are required on the site to download; please bear with that.)
