• Faster-rnnlm代码分析2


    也就是构造一棵Huffman Tree,输入是按照词汇频次由高到低排序的

    采用层次SoftMax的做法,是为了使得训练和预测时候的softmax输出加速,原有multinomal softmax,是和

    训练词汇量|V|成正比的,而现在由于二叉树的特性,变成了log(|V|),也就是平均每个预测只做log(|V|)次

    binarysoftmax。当然还有另外一种不采用HSTree的方法也就是nce(Noise Contrastive Estimation),后面再分析。

       

    由于</s>会被统计,这里为了模拟这个,用</s>代替了 ""

    [root@cq01-forum-rstree01.cq01.baidu.com faster-rnnlm]# cat shijiebei.txt

    </s> 喜欢 观看 巴西 足球 世界杯

    </s> 喜欢 观看 巴西 足球

    </s> 喜欢 观看 巴西 足球

    </s> 喜欢 观看 巴西

    </s> 喜欢 观看 巴西

    </s> 喜欢 观看

    </s> 喜欢

    喜欢

    [root@cq01-forum-rstree01.cq01.baidu.com faster-rnnlm]# pwd

    /home/users/chenghuige/other/faster-rnnlm.debug/faster-rnnlm

       

    (gdb) p vocab

    $63 = (const Vocabulary &) @0x7fffffffdb60: {static kWordOOV = 4294967295, words_ = std::vector of length 6, capacity 8 = {{freq = 15, word = 0x6ae1c0 "</s>"}, {freq = 8,

    word = 0x6aea30 "317262273266"}, {freq = 6, word = 0x6ae4a0 "271333277264"}, {freq = 5, word = 0x6aead0 "260315316367"}, {freq = 3, word = 0x6aeaf0 "327343307362"}, {freq = 1,

    word = 0x6aeba0 "312300275347261255"}}, hash_impl_ = 0x6ae780}

    (gdb) n

    203         for (int branch = 0; branch < ARITY; ++branch) {

    (gdb) n

    204         if (next_leaf_node >= 0 && weight[next_leaf_node] < weight[next_inner_node]) {

    (gdb) c

    Continuing.

       

    Breakpoint 6, HSTree::CreateHuffmanTree (vocab=..., layer_size=5) at hierarchical_softmax.cc:222

    222         return new HSTree(vocab.size(), layer_size, children);

    (gdb) p children

    $64 = std::vector of length 10, capacity 10 = {5, 4, 6, 3, 2, 1, 7, 8, 0, 9}

    (gdb) p weight

    $65 = std::vector of length 13, capacity 13 = {15, 8, 6, 5, 3, 1, 4, 9, 14, 23, 38, 0, 0}

       

    Breakpoint 7, HSTree::Tree::Tree (this=0x6aedf0, leaf_count=6, children=std::vector of length 10, capacity 10 = {...}) at hierarchical_softmax.cc:165

    165        }

    (gdb) p Tree

    $66 = {void (HSTree::Tree * const, int, const std::vector<int, std::allocator<int> > &)} 0x418718 <HSTree::Tree::Tree(int, std::vector<int, std::allocator<int> > const&)>

    (gdb) p *this

    $67 = {leaf_count_ = 6, root_node_ = 10, tree_height_ = 4, children_ = std::vector of length 10, capacity 10 = {5, 4, 6, 3, 2, 1, 7, 8, 0, 9}, path_lengths_ = std::vector of length 6, capacity 6 = {

    2, 4, 4, 4, 5, 5}, points_ = std::vector of length 246, capacity 246 = {10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

    0, 10, 9, 8, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 9, 8, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 9, 7, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 9, 7, 6,

    4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...}, branches_ = std::vector of length 240, capacity 240 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

    0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...}}

       

    仿照写了一个简化版的 这个看起来更清楚一些

    //主要输出接口是根据wordId来定位 路径上面的node和branch

    class HuffmanTree

    {

    public:

    //默认假定输入的freqs是从大到小排好序的 needSort = false

    //@TODO 对应 NumBranch不是2 比如3叉huffman 要检查叶子数目是否ok

    HuffmanTree(const vector<int64>& freqs, int numBranchces = 2)

    :_numBranches(numBranchces),

    _numLeaves(freqs.size()),

    _numInterNodes(freqs.size() - 1),

    _numNodes(_numLeaves + _numInterNodes),

    _childs(_numBranches, _numInterNodes),

    _root(_numNodes - 1),

    _weight(freqs.begin(), freqs.end()),

    _nodePaths(_numLeaves),

    _branchPaths(_numLeaves)

    {

    Init();

    Build();

    }

       

    int NumBranches() const

    {

    return _numBranches;

    }

       

    bool IsLeaf(int node) const

    {

    return node < _numLeaves;

    }

       

    int Root() const

    {

    return _root;

    }

       

    int NumLeaves() const

    {

    return _numLeaves;

    }

       

    size_t size() const

    {

    return _numLeaves;

    }

       

    const vector<int>& NodePaths(int node) const

    {

    return _nodePaths[node];

    }

       

    const vector<int>& BranchPaths(int node) const

    {

    return _branchPaths[node];

    }

       

    int Height() const

    {

    return _height;

    }

       

    protected:

    private:

    void Init()

    {

    _weight.resize(_numNodes, std::numeric_limits<int64>::max());

    }

       

    int InterNodeIndex(int index) const

    {

    return _numLeaves + index;

    }

       

    void Build()

    {

    int minLeafIndex = _numLeaves - 1;

    int minInterIndex = _numLeaves;

       

    vector<int> parents(_numNodes);

    vector<int> branches(_numNodes);

    //每次选取_numBranches个最小的节点,并将它们的和按照顺序

    for (int interIndex = 0; interIndex < _numInterNodes; interIndex++)

    {

    int64 weight = 0;

    int index = InterNodeIndex(interIndex);

    for (int branch = 0; branch < _numBranches; branch++)

    {

    if (minLeafIndex >= 0 && _weight[minLeafIndex] <= _weight[minInterIndex])

    {

    weight += _weight[minLeafIndex];

    _childs[branch][interIndex] = minLeafIndex;

    parents[minLeafIndex] = index;

    branches[minLeafIndex] = branch;

    minLeafIndex--;

    }

    else

    {

    weight += _weight[minInterIndex];

    _childs[branch][interIndex] = minInterIndex;

    parents[minInterIndex] = index;

    branches[minInterIndex] = branch;

    minInterIndex++;

    }

    }

    _weight[index] = weight;

    }

       

    for (int leafNode = 0; leafNode < _numLeaves; leafNode++)

    {

    for (int node = leafNode; node != _root; node = parents[node])

    {

    _nodePaths[leafNode].push_back(parents[node]);

    _branchPaths[leafNode].push_back(branches[node]);

    }

    if (_branchPaths[leafNode].size() > _height)

    {

    _height = _branchPaths[leafNode].size();

    }

    std::reverse(_nodePaths[leafNode].begin(), _nodePaths[leafNode].end());

    std::reverse(_branchPaths[leafNode].begin(), _branchPaths[leafNode].end());

    }

    }

       

    private:

    int _numBranches = 2;

       

    //@TODO 目前都是只考虑int 后面考虑改为 WordIndex ? unsigned

    int _numLeaves;

    int _numInterNodes;

    int _numNodes;

    int _root;

    int _height = 0;

       

    //存储所有node的weight最开始存储所有叶子节点从大到小排列,然后存储内部节点从小到大排列,root节点就是最后一个节点

    //保持和faster-rnnlm中策略一致

    vector<int64> _weight;

    ufo::Matrix<int> _childs;

       

    //记录每一个叶子节点对应从root到它自身路径的所有节点,不包括它自己

    vector<vector<int> > _nodePaths;

    //记录每一个叶子节点对应从root到它自身的所有路径的branch标记 比如 010100

    vector<vector<int> > _branchPaths;

    };

       

    测试程序

    TEST(huffmantree, func)

    {

    vector<int64> vec = { 15, 8, 6, 5, 3, 1 };

    HuffmanTree tree(vec);

    Pvec(tree._weight);

    for (int i = 0; i < tree.NumBranches(); i++)

    {

    Pval(i);

    Pvec(tree._childs[i]);

    }

       

    for (size_t i = 0; i < tree.size(); i++)

    {

    Pval(i);

    Pvec(tree.NodePaths(i));

    Pvec(tree.BranchPaths(i));

    }

    Pval(tree.Height());

    }

       

    测试结果

    [root@cq01-forum-rstree01.cq01.baidu.com ds]# pwd

    /home/users/chenghuige/rsc/app/search/sep/anti-spam/gezi/test/ds

    [root@cq01-forum-rstree01.cq01.baidu.com ds]# ./test_huffmantree

    [==========] Running 1 test from 1 test case.

    [----------] Global test environment set-up.

    [----------] 1 test from huffmantree

    [ RUN ] huffmantree.func

    I1109 16:06:23.855368 26257 test_huffmantree.cc:32] tree._weight --- 11

    I1109 16:06:23.855466 26257 test_huffmantree.cc:32] 0 15

    I1109 16:06:23.855478 26257 test_huffmantree.cc:32] 1 8

    I1109 16:06:23.855484 26257 test_huffmantree.cc:32] 2 6

    I1109 16:06:23.855489 26257 test_huffmantree.cc:32] 3 5

    I1109 16:06:23.855494 26257 test_huffmantree.cc:32] 4 3

    I1109 16:06:23.855499 26257 test_huffmantree.cc:32] 5 1

    I1109 16:06:23.855504 26257 test_huffmantree.cc:32] 6 4

    I1109 16:06:23.855509 26257 test_huffmantree.cc:32] 7 9

    I1109 16:06:23.855515 26257 test_huffmantree.cc:32] 8 14

    I1109 16:06:23.855520 26257 test_huffmantree.cc:32] 9 23

    I1109 16:06:23.855525 26257 test_huffmantree.cc:32] 10 38

    I1109 16:06:23.855530 26257 test_huffmantree.cc:35] i --- [0]

    I1109 16:06:23.855538 26257 test_huffmantree.cc:36] tree._childs[i] --- 5

    I1109 16:06:23.855543 26257 test_huffmantree.cc:36] 0 5

    I1109 16:06:23.855550 26257 test_huffmantree.cc:36] 1 6

    I1109 16:06:23.855554 26257 test_huffmantree.cc:36] 2 2

    I1109 16:06:23.855559 26257 test_huffmantree.cc:36] 3 7

    I1109 16:06:23.855563 26257 test_huffmantree.cc:36] 4 0

    I1109 16:06:23.855569 26257 test_huffmantree.cc:35] i --- [1]

    I1109 16:06:23.855574 26257 test_huffmantree.cc:36] tree._childs[i] --- 5

    I1109 16:06:23.855579 26257 test_huffmantree.cc:36] 0 4

    I1109 16:06:23.855584 26257 test_huffmantree.cc:36] 1 3

    I1109 16:06:23.855589 26257 test_huffmantree.cc:36] 2 1

    I1109 16:06:23.855594 26257 test_huffmantree.cc:36] 3 8

    I1109 16:06:23.855599 26257 test_huffmantree.cc:36] 4 9

    I1109 16:06:23.855604 26257 test_huffmantree.cc:41] i --- [0]

    I1109 16:06:23.855610 26257 test_huffmantree.cc:42] tree.NodePaths(i) --- 1

    I1109 16:06:23.855615 26257 test_huffmantree.cc:42] 0 10

    I1109 16:06:23.855620 26257 test_huffmantree.cc:43] tree.BranchPaths(i) --- 1

    I1109 16:06:23.855626 26257 test_huffmantree.cc:43] 0 0

    I1109 16:06:23.855631 26257 test_huffmantree.cc:41] i --- [1]

    I1109 16:06:23.855636 26257 test_huffmantree.cc:42] tree.NodePaths(i) --- 3

    I1109 16:06:23.855641 26257 test_huffmantree.cc:42] 0 10

    I1109 16:06:23.855646 26257 test_huffmantree.cc:42] 1 9

    I1109 16:06:23.855651 26257 test_huffmantree.cc:42] 2 8

    I1109 16:06:23.855656 26257 test_huffmantree.cc:43] tree.BranchPaths(i) --- 3

    I1109 16:06:23.855662 26257 test_huffmantree.cc:43] 0 1

    I1109 16:06:23.855667 26257 test_huffmantree.cc:43] 1 1

    I1109 16:06:23.855672 26257 test_huffmantree.cc:43] 2 1

    I1109 16:06:23.855677 26257 test_huffmantree.cc:41] i --- [2]

    I1109 16:06:23.855682 26257 test_huffmantree.cc:42] tree.NodePaths(i) --- 3

    I1109 16:06:23.855687 26257 test_huffmantree.cc:42] 0 10

    I1109 16:06:23.855692 26257 test_huffmantree.cc:42] 1 9

    I1109 16:06:23.855697 26257 test_huffmantree.cc:42] 2 8

    I1109 16:06:23.855702 26257 test_huffmantree.cc:43] tree.BranchPaths(i) --- 3

    I1109 16:06:23.855707 26257 test_huffmantree.cc:43] 0 1

    I1109 16:06:23.855711 26257 test_huffmantree.cc:43] 1 1

    I1109 16:06:23.855716 26257 test_huffmantree.cc:43] 2 0

    I1109 16:06:23.855721 26257 test_huffmantree.cc:41] i --- [3]

    I1109 16:06:23.855726 26257 test_huffmantree.cc:42] tree.NodePaths(i) --- 3

    I1109 16:06:23.855731 26257 test_huffmantree.cc:42] 0 10

    I1109 16:06:23.855736 26257 test_huffmantree.cc:42] 1 9

    I1109 16:06:23.855741 26257 test_huffmantree.cc:42] 2 7

    I1109 16:06:23.855746 26257 test_huffmantree.cc:43] tree.BranchPaths(i) --- 3

    I1109 16:06:23.855751 26257 test_huffmantree.cc:43] 0 1

    I1109 16:06:23.855756 26257 test_huffmantree.cc:43] 1 0

    I1109 16:06:23.855762 26257 test_huffmantree.cc:43] 2 1

    I1109 16:06:23.855767 26257 test_huffmantree.cc:41] i --- [4]

    I1109 16:06:23.855770 26257 test_huffmantree.cc:42] tree.NodePaths(i) --- 4

    I1109 16:06:23.855775 26257 test_huffmantree.cc:42] 0 10

    I1109 16:06:23.855780 26257 test_huffmantree.cc:42] 1 9

    I1109 16:06:23.855785 26257 test_huffmantree.cc:42] 2 7

    I1109 16:06:23.855790 26257 test_huffmantree.cc:42] 3 6

    I1109 16:06:23.855795 26257 test_huffmantree.cc:43] tree.BranchPaths(i) --- 4

    I1109 16:06:23.855800 26257 test_huffmantree.cc:43] 0 1

    I1109 16:06:23.855805 26257 test_huffmantree.cc:43] 1 0

    I1109 16:06:23.855810 26257 test_huffmantree.cc:43] 2 0

    I1109 16:06:23.855814 26257 test_huffmantree.cc:43] 3 1

    I1109 16:06:23.855819 26257 test_huffmantree.cc:41] i --- [5]

    I1109 16:06:23.855824 26257 test_huffmantree.cc:42] tree.NodePaths(i) --- 4

    I1109 16:06:23.855829 26257 test_huffmantree.cc:42] 0 10

    I1109 16:06:23.855834 26257 test_huffmantree.cc:42] 1 9

    I1109 16:06:23.855839 26257 test_huffmantree.cc:42] 2 7

    I1109 16:06:23.855844 26257 test_huffmantree.cc:42] 3 6

    I1109 16:06:23.855849 26257 test_huffmantree.cc:43] tree.BranchPaths(i) --- 4

    I1109 16:06:23.855854 26257 test_huffmantree.cc:43] 0 1

    I1109 16:06:23.855859 26257 test_huffmantree.cc:43] 1 0

    I1109 16:06:23.855865 26257 test_huffmantree.cc:43] 2 0

    I1109 16:06:23.855870 26257 test_huffmantree.cc:43] 3 0

    I1109 16:06:23.855875 26257 test_huffmantree.cc:45] tree.Height() --- [4]

    [ OK ] huffmantree.func (0 assertion, 0 ms)

    [----------] 1 test from huffmantree (0 ms total)

       

    [----------] Global test environment tear-down

    [==========] 1 test from 1 test case ran. (0 assertion total, 0 ms total)

    [ PASSED ] 1 test.

       

  • 相关阅读:
    [bzoj4408][Fjoi2016]神秘数
    BZOJ1102: [POI2007]山峰和山谷Grz
    BZOJ1098: [POI2007]办公楼biu
    BZOJ1097: [POI2007]旅游景点atr
    GDOI2018 新的征程
    BZOJ2084: [Poi2010]Antisymmetry
    回文树详解
    Codeforces739E. Gosha is hunting
    一道题17
    LOJ#6002. 「网络流 24 题」最小路径覆盖
  • 原文地址:https://www.cnblogs.com/rocketfan/p/4950246.html
Copyright © 2020-2023  润新知