• Java BigInteger中的oddModPow


    java大数类的模幂计算中,其中比较核心部分就是使用一种窗口算法来优化乘法,然后使用Montgomery取模乘的办法加速取模。

    在jdk中,它的源码如下:

    private BigInteger oddModPow(BigInteger y, BigInteger z) {
        /*
         * The algorithm is adapted from Colin Plumb's C library.
         *
         * The window algorithm:
         * The idea is to keep a running product of b1 = n^(high-order bits of exp)
         * and then keep appending exponent bits to it.  The following patterns
         * apply to a 3-bit window (k = 3):
         * To append   0: square
         * To append   1: square, multiply by n^1
         * To append  10: square, multiply by n^1, square
         * To append  11: square, square, multiply by n^3
         * To append 100: square, multiply by n^1, square, square
         * To append 101: square, square, square, multiply by n^5
         * To append 110: square, square, multiply by n^3, square
         * To append 111: square, square, square, multiply by n^7
         *
         * Since each pattern involves only one multiply, the longer the pattern
         * the better, except that a 0 (no multiplies) can be appended directly.
         * We precompute a table of odd powers of n, up to 2^k, and can then
         * multiply k bits of exponent at a time.  Actually, assuming random
         * exponents, there is on average one zero bit between needs to
         * multiply (1/2 of the time there's none, 1/4 of the time there's 1,
         * 1/8 of the time, there's 2, 1/32 of the time, there's 3, etc.), so
         * you have to do one multiply per k+1 bits of exponent.
         *
         * The loop walks down the exponent, squaring the result buffer as
         * it goes.  There is a wbits+1 bit lookahead buffer, buf, that is
         * filled with the upcoming exponent bits.  (What is read after the
         * end of the exponent is unimportant, but it is filled with zero here.)
         * When the most-significant bit of this buffer becomes set, i.e.
         * (buf & tblmask) != 0, we have to decide what pattern to multiply
         * by, and when to do it.  We decide, remember to do it in future
         * after a suitable number of squarings have passed (e.g. a pattern
         * of "100" in the buffer requires that we multiply by n^1 immediately;
         * a pattern of "110" calls for multiplying by n^3 after one more
         * squaring), clear the buffer, and continue.
         *
         * When we start, there is one more optimization: the result buffer
         * is implcitly one, so squaring it or multiplying by it can be
         * optimized away.  Further, if we start with a pattern like "100"
         * in the lookahead window, rather than placing n into the buffer
         * and then starting to square it, we have already computed n^2
         * to compute the odd-powers table, so we can place that into
         * the buffer and save a squaring.
         *
         * This means that if you have a k-bit window, to compute n^z,
         * where z is the high k bits of the exponent, 1/2 of the time
         * it requires no squarings.  1/4 of the time, it requires 1
         * squaring, ... 1/2^(k-1) of the time, it reqires k-2 squarings.
         * And the remaining 1/2^(k-1) of the time, the top k bits are a
         * 1 followed by k-1 0 bits, so it again only requires k-2
         * squarings, not k-1.  The average of these is 1.  Add that
         * to the one squaring we have to do to compute the table,
         * and you'll see that a k-bit window saves k-2 squarings
         * as well as reducing the multiplies.  (It actually doesn't
         * hurt in the case k = 1, either.)
         */
            // Special case for exponent of one
            if (y.equals(ONE))
                return this;
     
            // Special case for base of zero
            if (signum == 0)
                return ZERO;
     
            int[] base = mag.clone();
            int[] exp = y.mag;
            int[] mod = z.mag;
            int modLen = mod.length;
     
            // Make modLen even. It is conventional to use a cryptographic
            // modulus that is 512, 768, 1024, or 2048 bits, so this code
            // will not normally be executed. However, it is necessary for
            // the correct functioning of the HotSpot intrinsics.
            if ((modLen & 1) != 0) {
                int[] x = new int[modLen + 1];
                System.arraycopy(mod, 0, x, 1, modLen);
                mod = x;
                modLen++;
            }
     
            // Select an appropriate window size
            int wbits = 0;
            int ebits = bitLength(exp, exp.length);
            // if exponent is 65537 (0x10001), use minimum window size
            if ((ebits != 17) || (exp[0] != 65537)) {
                while (ebits > bnExpModThreshTable[wbits]) {
                    wbits++;
                }
            }
     
            // Calculate appropriate table size
            int tblmask = 1 << wbits;
     
            // Allocate table for precomputed odd powers of base in Montgomery form
            int[][] table = new int[tblmask][];
            for (int i=0; i < tblmask; i++)
                table[i] = new int[modLen];
     
            // Compute the modular inverse of the least significant 64-bit
            // digit of the modulus
            long n0 = (mod[modLen-1] & LONG_MASK) + ((mod[modLen-2] & LONG_MASK) << 32);
            long inv = -MutableBigInteger.inverseMod64(n0);
     
            // Convert base to Montgomery form
            int[] a = leftShift(base, base.length, modLen << 5);
     
            MutableBigInteger q = new MutableBigInteger(),
                              a2 = new MutableBigInteger(a),
                              b2 = new MutableBigInteger(mod);
            b2.normalize(); // MutableBigInteger.divide() assumes that its
                            // divisor is in normal form.
     
            MutableBigInteger r= a2.divide(b2, q);
            table[0] = r.toIntArray();
     
            // Pad table[0] with leading zeros so its length is at least modLen
            if (table[0].length < modLen) {
               int offset = modLen - table[0].length;
               int[] t2 = new int[modLen];
               System.arraycopy(table[0], 0, t2, offset, table[0].length);
               table[0] = t2;
            }
     
            // Set b to the square of the base
            int[] b = montgomerySquare(table[0], mod, modLen, inv, null);
     
            // Set t to high half of b
            int[] t = Arrays.copyOf(b, modLen);
     
            // Fill in the table with odd powers of the base
            for (int i=1; i < tblmask; i++) {
                table[i] = montgomeryMultiply(t, table[i-1], mod, modLen, inv, null);
            }
     
            // Pre load the window that slides over the exponent
            int bitpos = 1 << ((ebits-1) & (32-1));
     
            int buf = 0;
            int elen = exp.length;
            int eIndex = 0;
            for (int i = 0; i <= wbits; i++) {
                buf = (buf << 1) | (((exp[eIndex] & bitpos) != 0)?1:0);
                bitpos >>>= 1;
                if (bitpos == 0) {
                    eIndex++;
                    bitpos = 1 << (32-1);
                    elen--;
                }
            }
     
            int multpos = ebits;
     
            // The first iteration, which is hoisted out of the main loop
            ebits--;
            boolean isone = true;
     
            multpos = ebits - wbits;
            while ((buf & 1) == 0) {
                buf >>>= 1;
                multpos++;
            }
     
            int[] mult = table[buf >>> 1];
     
            buf = 0;
            if (multpos == ebits)
                isone = false;
     
            // The main loop
            while (true) {
                ebits--;
                // Advance the window
                buf <<= 1;
     
                if (elen != 0) {
                    buf |= ((exp[eIndex] & bitpos) != 0) ? 1 : 0;
                    bitpos >>>= 1;
                    if (bitpos == 0) {
                        eIndex++;
                        bitpos = 1 << (32-1);
                        elen--;
                    }
                }
     
                // Examine the window for pending multiplies
                if ((buf & tblmask) != 0) {
                    multpos = ebits - wbits;
                    while ((buf & 1) == 0) {
                        buf >>>= 1;
                        multpos++;
                    }
                    mult = table[buf >>> 1];
                    buf = 0;
                }
     
                // Perform multiply
                if (ebits == multpos) {
                    if (isone) {
                        b = mult.clone();
                        isone = false;
                    } else {
                        t = b;
                        a = montgomeryMultiply(t, mult, mod, modLen, inv, a);
                        t = a; a = b; b = t;
                    }
                }
     
                // Check if done
                if (ebits == 0)
                    break;
     
                // Square the input
                if (!isone) {
                    t = b;
                    a = montgomerySquare(t, mod, modLen, inv, a);
                    t = a; a = b; b = t;
                }
            }
     
            // Convert result out of Montgomery form and return
            int[] t2 = new int[2*modLen];
            System.arraycopy(b, 0, t2, modLen, modLen);
     
            b = montReduce(t2, mod, modLen, (int)inv);
     
            t2 = Arrays.copyOf(b, modLen);
     
            return new BigInteger(1, t2);
        }
    View Code

    首先需要介绍一下它对于指数运算的手段:

    * The idea is to keep a running product of b1 = n^(high-order bits of exp)
    * and then keep appending exponent bits to it.  The following patterns
    * apply to a 3-bit window (k = 3):
    * To append   0: square
    * To append   1: square, multiply by n^1
    * To append  10: square, multiply by n^1, square
    * To append  11: square, square, multiply by n^3
    * To append 100: square, multiply by n^1, square, square
    * To append 101: square, square, square, multiply by n^5
    * To append 110: square, square, multiply by n^3, square
    * To append 111: square, square, square, multiply by n^7
    View Code

    这段话的意思是,对于n^exp这个幂运算,在把exp转成二进制后,从高位开始向低位每次取大小为k的窗口,注释以k=3为例,计算也是从高位开始向低位二进制计算,例如高位是0,那么就相当于把指数/2,因此底数平方,高位是1,相当于平方后还要处理一个指数+1,因此还要乘一个n^1。以此类推,对于11来讲,为了保证尽可能平方而少乘法,因此可以把相邻的平方合并进去,因此本来应该是平方,乘n^1,平方,乘n^1,被化简成了平方,平方,乘n^3(因为高位n^1每向低位移动一位相当于多一个2), 后面的同理。

    那么,为什么要这么处理呢,后面的注释提到了,这是为了尽可能地减少乘法次数。先对n的奇数次幂的结果打表,那么如此处理后,对于每一个窗口,至多只会乘一次n^i,而i一定是奇数。因此,对于每次大小为k的窗口的计算,一共只需要k次平方和一次乘法。相比传统快速幂而言,少了若干次乘法(传统快速幂每次遇到1就要乘一次)。

    那么,窗口是不是越大越好呢?それはどうかな!因为如果窗口取的越大,那么打的表也会越大,因为对于这个奇数次幂的表来讲,它的最大值是n的2^k-1次幂,大小为2^(k-1),因此如果k太大,虽然窗口次数变少了,但是打表的开销也会变大,而根据简单的概率学,需要乘n^1的概率是1/2,n^3是1/4以此类推,因此实际上越远的幂积被使用的次数越少,但其实每个值都可能被用到,所以表并不浪费。因此需要考虑的是取窗口开销的影响,对于纪录大小的mag来讲,单独一位最大是Integer.MAX_VALUE,而据说对单独一位(int的32位)进行计算,jvm会有内部函数优化,因此窗口最大取到Integer.MAX_VALUE就足够了。

    它的取模计算方式也使用了一种优化的方案,叫Montgomery Multiply(蒙哥马利模乘)。因为传统的取模方式是除法,而除法在计算机中计算是非常耗时的,因此蒙哥马利模乘的方式是将取模转化成位运算,这样会大大加速取模的速度,达到优化的效果。那么转化的方式就是通过改变取模数来实现的。通过Montgomery 约简的方式使得取模数变成2的正整数次幂,而又保证了被取模的数在取模后不会有位数损失。这样保证了只要在Montgomery 域内,每次取模都变成了算术右移,付出的代价仅仅是开始幂乘的时候和幂乘结束后各需要进行一次从一般域向Montgomery域的运算或逆运算,但每次取模的速度都被大大提高了,在数值很大的时候,这种取模方式具有明显的优势。

            // Special case for exponent of one
            if (y.equals(ONE))
                return this;
    
            // Special case for base of zero
            if (signum == 0)
                return ZERO;
    View Code

    特判一下指数为1和底数为0的情况。

            int[] base = mag.clone();
            int[] exp = y.mag;
            int[] mod = z.mag;
            int modLen = mod.length;
    
            // Make modLen even. It is conventional to use a cryptographic
            // modulus that is 512, 768, 1024, or 2048 bits, so this code
            // will not normally be executed. However, it is necessary for
            // the correct functioning of the HotSpot intrinsics.
            if ((modLen & 1) != 0) {
                int[] x = new int[modLen + 1];
                System.arraycopy(mod, 0, x, 1, modLen);
                mod = x;
                modLen++;
            }
    View Code

    把类中的东西全拿出来存成数组,然后让modlen为偶数,方便二进制操作,而注释中提到了这也使得jvm可以执行内部函数提高执行效率。

            // Select an appropriate window size
            int wbits = 0;
            int ebits = bitLength(exp, exp.length);
            // if exponent is 65537 (0x10001), use minimum window size
            if ((ebits != 17) || (exp[0] != 65537)) {
                while (ebits > bnExpModThreshTable[wbits]) {
                    wbits++;
                }
            }
    View Code

    设置合适的窗口大小,说实话对于下面这个窗口数组,我只搞明白了为什么最后一位是Integer.MAX_VALUE,因为BigInteger里mag转化成数组之后,每位最大不会超过这个值,或者说,它只有9个十进制位,但是前面几个哨兵的取法我没太明白。

            // Calculate appropriate table size
            int tblmask = 1 << wbits;
    
            // Allocate table for precomputed odd powers of base in Montgomery form
            int[][] table = new int[tblmask][];
            for (int i=0; i < tblmask; i++)
                table[i] = new int[modLen];
    
            // Compute the modular inverse of the least significant 64-bit
            // digit of the modulus
            long n0 = (mod[modLen-1] & LONG_MASK) + ((mod[modLen-2] & LONG_MASK) << 32);
            long inv = -MutableBigInteger.inverseMod64(n0);
    View Code

    先建立好表的大小,tablemask是找到窗口的最高位的位置。然后是把模数打一个64位乘法逆元的表,用来做蒙哥马利域的转换。

            // Convert base to Montgomery form
            int[] a = leftShift(base, base.length, modLen << 5);
    
            MutableBigInteger q = new MutableBigInteger(),
                              a2 = new MutableBigInteger(a),
                              b2 = new MutableBigInteger(mod);
            b2.normalize(); // MutableBigInteger.divide() assumes that its
                            // divisor is in normal form.
    
            MutableBigInteger r= a2.divide(b2, q);
            table[0] = r.toIntArray();
    
            // Pad table[0] with leading zeros so its length is at least modLen
            if (table[0].length < modLen) {
               int offset = modLen - table[0].length;
               int[] t2 = new int[modLen];
               System.arraycopy(table[0], 0, t2, offset, table[0].length);
               table[0] = t2;
            }
    View Code

    这段一些细节没太看明白,前半部分是把底数转到蒙哥马利域。下面是补前导零补到modLen。

            // Set b to the square of the base
            int[] b = montgomerySquare(table[0], mod, modLen, inv, null);
    
            // Set t to high half of b
            int[] t = Arrays.copyOf(b, modLen);
    
            // Fill in the table with odd powers of the base
            for (int i=1; i < tblmask; i++) {
                table[i] = montgomeryMultiply(t, table[i-1], mod, modLen, inv, null);
            }
    
            // Pre load the window that slides over the exponent
            int bitpos = 1 << ((ebits-1) & (32-1));
    
            int buf = 0;
            int elen = exp.length;
            int eIndex = 0;
            for (int i = 0; i <= wbits; i++) {
                buf = (buf << 1) | (((exp[eIndex] & bitpos) != 0)?1:0);
                bitpos >>>= 1;
                if (bitpos == 0) {
                    eIndex++;
                    bitpos = 1 << (32-1);
                    elen--;
                }
            }
    View Code

    对底数的奇数次幂打表,表是在蒙哥马利域下的。

            int multpos = ebits;
    
            // The first iteration, which is hoisted out of the main loop
            ebits--;
            boolean isone = true;
    
            multpos = ebits - wbits;
            while ((buf & 1) == 0) {
                buf >>>= 1;
                multpos++;
            }
    
            int[] mult = table[buf >>> 1];
    
            buf = 0;
            if (multpos == ebits)
                isone = false;
    View Code

    处理第一次迭代。

            // The main loop
            while (true) {
                ebits--;
                // Advance the window
                buf <<= 1;
    
                if (elen != 0) {
                    buf |= ((exp[eIndex] & bitpos) != 0) ? 1 : 0;
                    bitpos >>>= 1;
                    if (bitpos == 0) {
                        eIndex++;
                        bitpos = 1 << (32-1);
                        elen--;
                    }
                }
    
                // Examine the window for pending multiplies
                if ((buf & tblmask) != 0) {
                    multpos = ebits - wbits;
                    while ((buf & 1) == 0) {
                        buf >>>= 1;
                        multpos++;
                    }
                    mult = table[buf >>> 1];
                    buf = 0;
                }
    
                // Perform multiply
                if (ebits == multpos) {
                    if (isone) {
                        b = mult.clone();
                        isone = false;
                    } else {
                        t = b;
                        a = montgomeryMultiply(t, mult, mod, modLen, inv, a);
                        t = a; a = b; b = t;
                    }
                }
    
                // Check if done
                if (ebits == 0)
                    break;
    
                // Square the input
                if (!isone) {
                    t = b;
                    a = montgomerySquare(t, mod, modLen, inv, a);
                    t = a; a = b; b = t;
                }
            }
    
            // Convert result out of Montgomery form and return
            int[] t2 = new int[2*modLen];
            System.arraycopy(b, 0, t2, modLen, modLen);
    
            b = montReduce(t2, mod, modLen, (int)inv);
    
            t2 = Arrays.copyOf(b, modLen);
    
            return new BigInteger(1, t2);
    View Code

    主循环迭代,结束后把结果从蒙哥马利域转换回来返回结果。

    相关证明:

    https://blog.csdn.net/weixin_46395886/article/details/112988136 蒙哥马利模乘

  • 相关阅读:
    SQL学习指南第三篇
    SQL学习指南第二篇
    Rebuilding Roads
    TOJ4244: Sum
    K-th Number
    【模板】后缀数组
    冰水挑战
    旅途
    Monkey and Banana
    Max Sum Plus Plus
  • 原文地址:https://www.cnblogs.com/youchandaisuki/p/15076382.html
Copyright © 2020-2023  润新知