• 学到一个编码技巧:用重复写入代替if判断,减少程序分支


    作者:张富春(ahfuzhang),转载时请注明作者和引用链接,谢谢!


    近期阅读了rust标准库的hashbrown库(也就是一个hashmap的实现),并搞了一个中文注释的版本,有兴趣的同学请看:https://github.com/ahfuzhang/rust-hashbrown-v0.12.0
    hashbrown的原理来自google开源的swiss table。我之前写了一篇swiss table的介绍:《Swisstable:C++中比std::unordered_map更快的hash表》。

    hashbrown中,hash表的冲突管理是通过一个与buckets一样长的ctrl byte数组来实现的,每个桶的位置被占用后,就会把对应下标的ctrl byte写为key的高7bit。实现的代码如下:

        /// Sets a control byte, and possibly also the replicated control byte at
        /// the end of the array.
        #[inline]
        unsafe fn set_ctrl(&self, index: usize, ctrl: u8) {
            // Replicate the first Group::WIDTH control bytes at the end of
            // the array without using a branch:
            // - If index >= Group::WIDTH then index == index2.
            // - Otherwise index2 == self.bucket_mask + 1 + index.
            //
            // The very last replicated control byte is never actually read because
            // we mask the initial index for unaligned loads, but we write it
            // anyways because it makes the set_ctrl implementation simpler.
            //
            // If there are fewer buckets than Group::WIDTH then this code will
            // replicate the buckets at the end of the trailing group. For example
            // with 2 buckets and a group size of 4, the control bytes will look
            // like this:
            //
            //     Real    |             Replicated
            // ---------------------------------------------
            // | [A] | [B] | [EMPTY] | [EMPTY] | [A] | [B] |
            // ---------------------------------------------
            let index2 = ((index.wrapping_sub(Group::WIDTH)) & self.bucket_mask) + Group::WIDTH;
    
            *self.ctrl(index) = ctrl;
            *self.ctrl(index2) = ctrl;  //index和index2几乎是一样,这里为什么要重复再写一次???
        }
    

    代码中:

    • index是需要写入的ctrl byte数组的下标
    • Group::WIDTH为16字节
    • index.wrapping_sub(Group::WIDTH))是在无符号整数上做二进制减法。
      • rust中存在会溢出的减法,一定要使用wrapping_sub(),否则会panic。我在这里做了个实验。
    • *self.ctrl(index) = ctrl;这行代码,对数组中指定下标的ctrl byte进行赋值
    • 疑惑的是这行:*self.ctrl(index2) = ctrl;
      • 当index>=16时,index和index2的值完全相等,重复赋值,看似是没有意义的;
      • 当index<16是,index2会出现在ctrl byte数组后的0~15字节中,是超过了ctrl byte数组范围的;
      • 用python代码实验一下:
        • 假定桶的长度为1024(hashbrown中,桶的长度一定是2的幂),则bucket_mask等于1023,也就是b01111_11111
        • 假设index为2,则: (2 - 16) & (2**10-1) + 16 = 1026
        • 1026指向了ctrl byte数组尾部的第三个下标

    看不懂的时候再认真读读注释:把ctrl byte数组的第一个Group复制到最后,从而避免使用分支!


    现在,我们回到最初,从头开始解释这个写法:

    • hashbrown的桶长度必须是2的幂,假定此处是1024个
      • 分配KV数据的结构可以表示为:
    struct hashbrown{
        struct {
           KEY_TYPE key;
           VALUE_TYPE value;
        } buckets[1024];
    }
    
    • hashbrown采用相邻地址法来解决hash冲突,因此需要分配一个与桶长度一致的ctrl byte数组:
      • 每16个ctrl byte成为一个Group,使用SSE的指令能够一次搜索16字节,可以提升性能
      • 在分配的时候,在ctrl byte数组的尾部,再多分配16字节。这16字节就是为了复制ctrl byte数组头部的16字节。
      • ctrl byte数组,包含其后的16字节,都只为0x80,即 b1000_0000,最高位为1说明这个位置未使用。
      • Ctrl byte数组的内容可以表示为:
    struct hashbrown{
        struct {
           KEY_TYPE key;
           VALUE_TYPE value;
        } buckets[1024];
        byte ctrls[1024+16];
    }
    
    • 为什么要把前16个ctrl byte复制在数组末位之后呢?这里涉及hashmap在插入时候搜索空桶的逻辑:

      • 每次根据KEY计算出一个64位的hashcode
      • hashcode 取模桶的长度得到了桶的下标
      • 如果这个位置未被占用,则使用这个位置,并把ctrls数组中的对应下标写为hashcode的高7bit
      • 如果这个位置被占用了,则需要从相邻的位置去寻找空位。
    • hashbrown(或者说swiss table)的精彩之处就在于相邻位置的查找:

      • ctrls数组中连续的16字节(128bit)称为一个Group
      • 搜索的时候,把128bit加载到SSE的寄存器
      • 通过SSE指令可以一次性判断16字节的内容是否有空位
    • 下面是搜索插入位置的代码:

        /// Searches for an empty or deleted bucket which is suitable for inserting
        /// a new element.
        ///
        /// There must be at least 1 empty bucket in the table.
        #[inline]
        fn find_insert_slot(&self, hash: u64) -> usize {
            let mut probe_seq = self.probe_seq(hash);  //构造ProbeSeq对象,进行三角数跳跃(第一次跳跃1个,第二次(在上一次基础上)跳跃2个,第三次跳跃3个……)
            loop {
                unsafe {  //当这一字节处于整个ctrl数组的边缘的时候,就必须在最后加一个Group,以此避免溢出
                    let group = Group::load(self.ctrl(probe_seq.pos));  //加载当前Group
                    if let Some(bit) = group.match_empty_or_deleted().lowest_set_bit() {  //当前group找个空位
                        let result = (probe_seq.pos + bit) & self.bucket_mask;  //这个就是找到的插入位置
    
                        // In tables smaller than the group width, trailing control
                        // bytes outside the range of the table are filled with
                        // EMPTY entries. These will unfortunately trigger a
                        // match, but once masked may point to a full bucket that
                        // is already occupied. We detect this situation here and
                        // perform a second scan starting at the beginning of the
                        // table. This second scan is guaranteed to find an empty
                        // slot (due to the load factor) before hitting the trailing
                        // control bytes (containing EMPTY).
                        if unlikely(is_full(*self.ctrl(result))) { //当桶的长度小于16时,触发这里的逻辑
                            debug_assert!(self.bucket_mask < Group::WIDTH);
                            debug_assert_ne!(probe_seq.pos, 0);
                            return Group::load_aligned(self.ctrl(0))
                                .match_empty_or_deleted()
                                .lowest_set_bit_nonzero();
                        }
    
                        return result;
                    }
                }
                probe_seq.move_next(self.bucket_mask);  //找不到空位,就跳跃到下个三角数
            }
        }  //最坏的情况,这个函数会遍历整个的ctrls数组,直到找到空位。不过上层函数保障了一定有空位
    
    • 以上的代码解释了为什么要在ctrls数组的最后多加16个字节:假设桶长度为1024,假设当前开始load group的下标为 (1024-1)-16+1=1008,这个位置距离ctrls的末位不足一个Group。
    • 一般性的思维就是:加个if判断,处于边界的时候特殊处理。而作者则是多分配了一个Group在尾部,使得按照Group加载的时候,一定不会溢出。
    • 同时,每次写入ctrls数组时,前16个ctrl byte总是被复制到了最后16字节的溢出区;这样,从末位加载的Group起始包含了回绕到头部的ctrl byte的信息。作者通过复制,来解决了搜索时候的回绕,且不用if语句来做特殊判断。
    • 总结一下:
      • 为了解决hash冲突,使用了额外的ctrl byte数组来表示buckets的占用情况;
      • 为了高效的搜索桶的占用情况,使用了以Group为单位的搜索,通过SSE指令一次搜索16个位置;
      • 为了解决Group加载在边缘位置可能溢出的问题,使用了额外的16字节来作为溢出区,避免了用if去判断;
      • 当搜索到末尾再回绕到头部搜索的时候,ctrls数组的前16字节在写入时就会复制到末位的溢出区;这样,回绕的时候,尾部的bit等于头部bit的内容;仍然也不需要if进行边界条件的判断。
      • 重复写入一个byte是一条指令,用if语句判断边界条件也是一条指令。相比之下,大多数时候的重复写入代替了if语句,成本上与直接使用if一致,并没有浪费;并且,替代了if语句,使得不需要CPU做分支预测等工作,理论上能够提升指令cache的命中率,并提升性能。
  • 相关阅读:
    ThreadLocal的原理和使用场景
    sleep wait yield join方法的区别
    GC如何判断对象可以被回收
    双亲委派
    ConcurrentHashMap原理,jdk7和jdk8版本
    hashCode和equals
    接口和抽象类区别
    为什么局部内部类和匿名内部类只能访问局部final变量
    【GAN】基础GAN代码解析
    TF相关codna常用命令整理
  • 原文地址:https://www.cnblogs.com/ahfuzhang/p/16149823.html
Copyright © 2020-2023  润新知