• RNN-LSTM讲解-基于tensorflow实现


    cnn卷积神经网络在前面已经有所了解了,目前博主也使用它进行了一个图像分类问题,基于kaggle里面的food-101进行的图像识别,识别率有点感人,基于数据集的关系,大致来说还可行。
    下面我就继续学习rnn神经网络。

    rnn神经网络(递归/循环神经网络)模式如下:

    我们在处理文字等问题的时候,我们的输入会把上一个时间输出的数据作为下一个时间的输入数据进行处理。
    例如:我们有一段话,我们将其分词,得到t个数据,我们分别将每一个词传入到x0,x1....xt里面,当x0传入后,会得到一个结果h0,同时我们会将处理后的数据传入到下个时间,到下个时间的时候,我们会再传入一个数据x1,同时还有上一个时间处理后的数据,将这两个数据进行整合计算,然后再向下传输,一直到结束。
    rnn本质来说还是一个bp回路,不过他只是比bp网络多一个环节,即它可以反馈上一时间点处理后的数据。

    上图细化如下:


    rnn实际上还是存在梯度消失的问题,因此如上图所示,当我们在第一个时间输入的数据,可能在很久之后他就已经梯度消失了(影响很小),因此我们使用lstm(long short trem memory)

    上图有三个门:输入门    忘记门   输出门
    1.输入门:通过input * g 来判断是否输入,如果不输入就为0,输入就是0,以此判断信号是否输入
    2.忘记门:这个信号是否需要衰减多少,可能为50%,衰减是根据信号来判断。
    3.输入门:通过判断是否输出,或者输出多少,例如输出50%。
    因此上述图可化为:

    可以看出,这三个门,所有得影响都是关于输入和上一个数据得输出来进行计算的。

    可以看下图:

    我们使用lstm得话,通过三个门决定信号是否向下传输,传输多少都可以控制,是否传入信号,输出信息都进行控制。

    下面我们还是用tensorflow实现,数据集还是手写数字,虽然rnn主要是用在文字和语言上,但是它依旧可以用在图片上。
    下面给出代码:

    ```python
    import tensorflow as tf
    from tensorflow.contrib import rnn
    from tensorflow.examples.tutorials.mnist import  input_data
    mnist=input_data.read_data_sets("MNNIST_data",one_hot=True)
    
    #输入图片为 28*28
    n_inputs=28#输入一行,一行有28个像素
    max_time=28#一共28行,所以为28*28
    lstm_size=100#100个隐藏单元
    batch_size=50
    n_classes=10
    n_batch=mnist.train.num_examples//batch_size#计算一共多少批次
    
    #这里none表示第一个维度可以是任意长度
    x=tf.placeholder(tf.float32,[None,784])
    
    y=tf.placeholder(tf.float32,[None,10])
    
    #初始化权值
    weights=tf.Variable(tf.truncated_normal([lstm_size,n_classes],stddev=0.1))
    #初始化偏置值
    biases=tf.Variable(tf.constant(0.1,shape=[n_classes]))
    
    ##定义Rnn 网络
    def RNN(X,weights,biases):
        inputs=tf.reshape(X,[-1,max_time,n_inputs])
        #定义lstm基本cell
        lstm_cell = rnn.BasicLSTMCell(lstm_size)
        #lstm_cell=tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(lstm_size)
        outputs,final_state=tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
        results=tf.nn.softmax(tf.matmul(final_state[1],weights)+biases)
        return results
    prediction=RNN(x,weights,biases)
    #损失函数
    cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
    #优化器
    train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    #保存结果
    correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
    
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    init=tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(6):
            for batch in range(n_batch):
                batch_xs,batch_ys=mnist.train.next_batch(batch_size)
                sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
    
            acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
            print("iter:"+str(epoch)+"testing accuracy"+str(acc))





    ```
    运行结果如下:

  • 相关阅读:
    gin使用validator库参数校验若干实用技巧
    在gin框架中使用JWT
    使用zap接收gin框架默认的日志并配置日志归档
    gin框架路由拆分与注册
    Gin框架介绍及使用
    GO学习-(39) 优雅地关机或重启
    GO学习-(38) Go语言结构体转map[string]interface{}的若干方法
    WPF中不规则窗体与WindowsFormsHost控件的兼容问题完美解决方案
    [ 夜间模式 ] NightVersion
    HDU1518 Square(DFS)
  • 原文地址:https://www.cnblogs.com/lh9527/p/9527-11.html
Copyright © 2020-2023  润新知