• nnet3的代码分析


    nnet3/nnet-common.h

    定义了Index(n, t, x)三元组,表示第nbatch中第t帧。

    并声明了关于IndexCindex的一些读写操作。

       

    nnet3/nnet-nnet.h

    声明了NetworkNode(主要包含其类型以及索引信息)

    声明了Nnetnnet3网络类)

    private:

    //网络中的组件名列表

    std::vector<std::string> component_names_;

    //网络中实际的组件指针列表,同一组件可能出现多次

    std::vector<Component*> components_;

    //网络中结点名列表,即:inputscomponents以及outputs

    //同一组件名会出现两次:foo-inputfoo

    //因为foo-input有其自己的NetworkNode索引

    std::vector<std::string> node_names_;

    //网络中实际的结点指针列表

    std::vector<NetworkNode> nodes_;

    以及关于以上数据成员的实用函数。

       

    nnet3/nnet-component-itf.h

    Componentitfinterface,接口)

    class Component

    主要包含以下函数:

    Propagate //正向传播

    Backprop //反向传播

    StoreStats //储存平均激活值、非线性函数微分平均值

    ZeroStats //stats清零

    GetInputIndexes //只适用于非简单组件

    IsComputable //只适用于非简单组件

    ReorderIndexes //只适用于非简单组件

    以及实用函数

    class RandomComponent: public Component

    随机数生成的组件

    class UpdatableComponent: public Component

    参数扰动率

    学习率

    学习率因子

    实际学习率(实际学习率=学习率*学习率因子)

    冻结自然梯度更新

    每个minibatch最大参数变换率(NnetTrainerL2正则化的形式使用)

    标准L2正则化参数

    的设定、修改、查询

    class NonlinearComponent: public Component

    由于该类不修改特征维数,因子该类是sigmoidsoftmaxReLU的基类

    该类

    储存激活平均值

    储存训练中的微分

    负责模型初始化

    负责IO

    nnet3/nnet-simple-component.h

    class PnormComponent: public Component

    p-norm的公式:

    对维数为intput_dim的输入进行降维,输出维数为output_dim

    PropagateBackprop函数十分简单,具体关于p-norm单元的实现位于

    kaldi::CuMatrixBase::GroupPnorm

    Kaldi::CuMatrixBase::DiffGroupPnorm

    class DropoutComponent : public RandomComponent

    DropoutComponent组件对输入以dropout比例随机置零,而梯度只在非零的输入处进行反向传播。通常只在训练期间使用此组件,但不在测试时间使用

    Dropout: A Simple Way to Prevent Neural Networks from Overfitting"

       

    Propogate()

    //初始化一个元素取值范围为[0,1]的向量y

    const_cast<CuRand<BaseFloat>&>(random_generator_).RandUniform(out);

    out->Add(-dropout);

    out->ApplyHeaviside();

    out->MulElements(in);

       

    通过设置dropout_per_frame_,可以以帧的元素为单位dropout:

    [[0,1,1,1],[1,0,1,1],[1,1,0,1],[1,1,1,0],[1,1,1,0]]

    或帧为单位进行随机丢弃:

    [[1,1,1,1],[0,0,0,0],[0,0,0,0],[1,1,1,1],[0,0,0,0]]

    class ElementwiseProductComponent: public Component

    点乘组件,用于降维

    对于10维输入向量

    (0.7,0.5,1.0,0.2,0.9,0.0,0.3,0.1,0.6,0.8)

    假设输出维数为5,则10/5=2,两两相乘:

    (0.7*0.5,1.0*0.2,0.9*0.0,0.3*0.1,0.6*0.8)

    结果为

    (0.35,0.20,0.0,0.03,0.48)

    class SigmoidComponent: public NonlinearComponent

       

    class TanhComponent: public NonlinearComponent

       

    class RectifiedLinearComponent: public NonlinearComponent

       

    class AffineComponent: public UpdatableComponent

       

    class BlockAffineComponent : public UpdatableComponent

       

    class RepeatedAffineComponent: public UpdatableComponent

       

    class NaturalGradientRepeatedAffineComponent: public RepeatedAffineComponent

       

    class SoftmaxComponent: public NonlinearComponent

    Softmax损失函数(归一化指数函数):

    其中o是输出向量

    Backprop()

    对于softmax函数的微分,令:

    该函数的雅可比矩阵为:

    令输出向量微分为e,输入向量微分为d,有:

       

    nnet3/nnet-computation.h

    负责实际的计算。

    声明了ComputationRequestCommandTypeNnetComputation等类。

    struct ComputationRequest

    //计算需要的输入

    std::vector<IoSpecification> inputs;

    //计算预期的输出

    std::vector<IoSpecification> outputs;

    以及关于以上数据成员的实用函数

    enum CommandType

    神经网络计算类型,如:

    kPropagate

    kBackprop

    kAllocMatrix

    struct NnetComputation

    编译后的神经网络具体计算特定步骤

    给定NnetComputationRequest

    就可编译得到该结构体

    数据成员包括:

    (子)矩阵信息及其索引(使用索引而不存储实际的矩阵)

    矩阵

    计算类型(CommandType

    计算所依赖的输入输出Index

    nnet3/nnet-analyze.h

    检测计算是否能有效进行。

    主要的类:

    class ComputationAnalysis

    private:

    const NnetComputation &computation_;

    const Analyzer &analyzer_;

    ComputationVariables variables;

    std::vector<CommandAttributes> command_attributes;

    std::vector<std::vector<Access> > variable_accesses;

    std::vector<MatrixAccesses> matrix_accesses;

    成员函数:

    访问索引s的第一个非初始化指令

    访问索引s的第一个指令

    访问索引s的最后一个指令

    访问索引s的最后一个写指令

    访问索引s的无效指令

    访问矩阵索引m的第一个非初始化指令

    访问矩阵索引m的最后一个指令

    class ComputationChecker

    ComputationAnalysis类似

    主要检测:

    维数一致性检测

    未定义变量读取检测

    读写冲突检测(是否是写完再读)

    矩阵访问有效性检测

    矩阵压缩检测

    nnet3/nnet-example.h

    struct NnetIo

    std::vector<Index> indexes;

    GeneralMatrix features;

    特征(以及后验)的读写

    struct NnetExample

    //minibatch结构体

    std::vector<NnetIo> io;

    及其实用函数

    以及一些关于NnetExample的比较、哈希等函数

       

       

  • 相关阅读:
    git 无法提交到远程服务器【转载】
    vscode 常用快捷键
    mongodb nodejs一个有自增id的功能
    C++ lambda表达式与函数对象
    TypeScript的async, await, promise,多参数的调用比较(第2篇)
    了解TypeScript的async,await,promise(第1篇)
    TyepScript判断一个变量是null, or undefined
    MongoClient 对 Mongodb的 增删改查 操作
    TypeScript第一个Promise程序
    C++基类的继承和多态
  • 原文地址:https://www.cnblogs.com/JarvanWang/p/9152625.html
Copyright © 2020-2023  润新知