• char-rnn-tensorflow源码解析及结构流程分析


    char-rnn-tensorflow由飞飞弟子karpathy编写,展示了如何用tensorflow来搭建一个基本的RNN(LSTM)网络,并进行基于char的seq2seq进行训练。

    数据读取部分

    data文件夹下的input.txt为示例用的莎翁剧本,在数据读取阶段的preprocess函数中,将基于该文本文件生产词汇表文件vocal.pkl(记录词的索引)和data.npy(将训练用的文字转换成索引的文件)。

    其中,self.vocab就是组织的字典文件,给出任意char能查询到它的index,如[(3:c),(25:y)....]。如果是第一次读取,Self.char(索引对应字符)被dump到vocab.pkl文件中。读取的文字素材内容也通过查字典文件的方式被

    转成索引队列,存入tensor变量和data.npy文件中。

    接下来确定每一批训练用的Batch的数据:

    最终得到一个epoch需要训练的batch数量,训练数据x_batch和y_batch(y_batch[i]存储了x_batch[i+1]的char,并在结尾处循环),两者都是长度为num_batch的list,每一个item是一个batch需要训练的数据tensor(或标签数据),

    即batch_size组数据(上图的第一个50),每一组数据有seq_length个word(上图的第二个50)。

    model的建立

    这里使用了rnn.BasicLSTMCell,也就是基本的LSTM来构建cell,隐层单元数为args.rnn_size(默认128个),一个cell中的layer层数为2.隐层单元的个数也就对应了训练出来的词的dense-vector的维度,隐层单元的矩阵类似于

    word2vec中训练用的隐层矩阵。模型中的placeholder参数self.input_data和self.target的shape为[args.batch_size, args.seq_length],经过embedding查找后,被转化为稠密向量inputs(长度为batch_size的list,每一个item的

    shape为[args.seq_length,args.rnn_size],如上所述,args.rnn_size即是dense-vector的维度,每一个word由dense-vector维度的词向量来表示)

    关于embedding矩阵的构建,可以参考下图:

    上例将embedding矩阵初始化为one-hot编码的对角矩阵,如果不进行初始化(就像char-rnn例子里面一样),则数据会在initial的时候被随机初始化,如下图:

     最终,input会被压成一个长度为num_steps的列表,每个元素是[batch_size, input_size]的2-D维的tensor

    loop函数,沿用模型的参数(w和b)来循环生成下一个词。训练的时候 inputs里面已经有完整的训练sample了,所以loop函数被设置为null,不使用,inference的时候这部分是缺失的,需要我们用loop来生成。

    MultiRNNCell函数构造了一个时间步的多层rnn

    legacy_seq2seq.rnn_decoder负责实现将其循环num_steps个时间步,这里的num_step就等价于seq_length。

    这里附上rnn_decoder的伪代码

    接下来就是手动求导 手动优化了:

    训练完成后,Sample的流程总体上中规中矩,每次喂一个char给模型,吐出一个char,重复设定的num次。不过比较奇怪的是最后一步,在chars_size维的结果向量中选取最大概率的索引时,使用了一个奇怪的函数,

    weighted_pick,这个函数的最终输出结果和随机数有关,这个随机性和系统的结果有什么关联,看不懂。经高手提醒,想明白了,这里没有直接使用softmax,而是在保证概率优先的前提下,加入了一些随机性,
    让一个随机数落入weight组成的区间中去。

  • 相关阅读:
    Qt:移动无边框窗体(使用Windows的SendMessage)
    github atom 试用
    ENode框架Conference案例转载
    技术
    NET 领域驱动设计实战系列总结
    mac 配置Python集成开发环境
    User、Role、Permission数据库设计ABP
    Oracle 树操作
    Oracle 用户权限管理方法
    Web Api 2, Oracle and Entity Framework
  • 原文地址:https://www.cnblogs.com/punkcure/p/8240989.html
Copyright © 2020-2023  润新知