• TFboy养成记 CNN


    1/先解释下CNN的过程:

    首先对一张图片进行卷积,可以有多个卷积核,卷积过后,对每一卷积核对应一个chanel,也就是一张新的图片,图片尺寸可能会变小也可能会不变,然后对这个chanel进行一些pooling操作。

    最后pooling输出完成,这个算作一个卷积层。

    最后对最后一个pooling结果进行一个简单的MLP的判别其就好了

    2.代码分步:

    2.1 W and bias:注意不要将一些W设为0,一定要注意,这个会在后面一些地方讲到

    1 #注意不要将一些W设为0,一定要注意,这个会在后面一些地方讲到
    2 def getWeights(shape):
    3     return tf.Variable(tf.truncated_normal(shape,stddev= 0.1))
    4 def getBias(shape):
    5     return tf.Variable(tf.constant(0.1))
    View Code

    2.2 卷积层操作:

    首先说下tf.nn.conv2d这个函数:

    其中官方解释:

    这里主要需要了解的是strides的含义:其shape表示的是[batch, in_height, in_width, in_channels]。需要注意的是,看我们在Weights初始化时的shape,我们自己定义的shape格式是[h,w,inchanel,outchanel]   --->chanel也就是我们理解的厚度。

    1 def conv2d(x,W):
    2     return tf.nn.conv2d(x,W,strides = [1,1,1,1],padding="SAME")
    3 #ksize
    4 def maxpooling(x):
    5     return tf.nn.max_pool(x,ksize=[1,2,2,1],strides = [1,2,2,1],padding= "SAME")
    View Code

     关于data_format

    padding也有两种方式:

    其他地方其实也没有什么新操作所有代码在下面:

     1 # -*- coding: utf-8 -*-
     2 """
     3 Spyder Editor
     4 
     5 This is a temporary script file.
     6 """
     7 from tensorflow.examples.tutorials.mnist import input_data
     8 import tensorflow as tf
     9 import numpy as np
    10 #注意不要将一些W设为0,一定要注意,这个会在后面一些地方讲到
    11 def getWeights(shape):
    12     return tf.Variable(tf.truncated_normal(shape,stddev= 0.1))
    13 def getBias(shape):
    14     return tf.Variable(tf.constant(0.1))
    15 #构造卷积层 strides前一个跟最后后一个为1,其他表示方向,padding一般是有两种方式 ,一个是SAME还有一个是VALID
    16 #前者卷积后不改变大小后一个卷积后一般会变小
    17 #strides--->data_format:data_format: An optional string from: "NHWC", "NCHW". Defaults to "NHWC". Specify the data format of the input and output data. With the default format "NHWC", the data is stored in the order of: [batch, height, width, channels]. Alternatively, the format could be "NCHW", the data storage order of: [batch, channels, height, width].
    18 #
    19 def conv2d(x,W):
    20     return tf.nn.conv2d(x,W,strides = [1,1,1,1],padding="SAME")
    21 #ksize
    22 def maxpooling(x):
    23     return tf.nn.max_pool(x,ksize=[1,2,2,1],strides = [1,2,2,1],padding= "SAME")
    24 def compute_acc(v_xs,v_ys):
    25     global predict
    26     y_pre = sess.run(predict,feed_dict = {xs:v_xs,keep_prob:1})
    27     tmp = tf.equal(tf.arg_max(y_pre,1),tf.arg_max(v_ys,1))
    28     accuracy = tf.reduce_mean(tf.cast(tmp,tf.float32))
    29     return sess.run(accuracy,feed_dict = {xs:v_xs,ys:v_ys,keep_prob:1})
    30     
    31     
    32 mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
    33 xs = tf.placeholder(tf.float32,[None,28*28])
    34 ys = tf.placeholder(tf.float32,[None,10])
    35 keep_prob = tf.placeholder(tf.float32)
    36 
    37 x_images = tf.reshape(xs,[-1,28,28,1])
    38 
    39 W_c1 = getWeights([5,5,1,32])
    40 b_c1 = getBias([32])
    41 h_c1 = tf.nn.relu(conv2d(x_images,W_c1)+b_c1)
    42 h_p1 = maxpooling(h_c1)
    43 
    44 W_c2 = getWeights([5,5,32,64])
    45 b_c2 = getBias([64])
    46 h_c2 = tf.nn.relu(conv2d(h_p1,W_c2)+b_c2)
    47 h_p2 = maxpooling(h_c2)
    48 
    49 W_fc1 = getWeights([7*7*64,1024])
    50 b_fc1 = getBias([1024])
    51 h_flat = tf.reshape(h_p2,[-1,7*7*64])
    52 h_fc1 = tf.nn.relu(tf.matmul(h_flat,W_fc1)+b_fc1)
    53 h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob)
    54 
    55 W_fc2 = getWeights([1024,10])
    56 b_fc2 = getBias([10])
    57 predict = tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)
    58 
    59 loss = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(predict),
    60                                      reduction_indices=[1]))
    61 train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
    62 
    63 sess = tf.Session()
    64 sess.run(tf.initialize_all_variables())
    65 for i in range(1000):
    66     batch_xs, batch_ys = mnist.train.next_batch(100)
    67     sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
    68     if i % 50 == 0:
    69         print (compute_acc(mnist.test.images,mnist.test.labels))
    View Code

    需要注意的是nn.dropout()

  • 相关阅读:
    git remote: Support for password authentication was removed on August 13, 2021
    win10 安装vue 详解包括node.js、npm、webpack
    solr window 安装与启动
    solr 创建 core
    idea 创建 springboot 模块报错解决
    c# 设计模式篇
    javascript(DHTML)代码和客户端应用程序代码之间实现双向通信.
    委托,匿名方法,Lambda 表达式 的关系
    使用泛型实现单例模式提供者
    asp.net 文件编码问题
  • 原文地址:https://www.cnblogs.com/silence-tommy/p/7110272.html
Copyright © 2020-2023  润新知