• 题解 P5282 【模板】快速阶乘算法


    传送门

    总算是学会了这个算法......


    【前置芝士】

    1. 多项式乘法
    2. 任意模数多项式乘法
    3. 多项式连续点值平移

    前两个用于处理任意模数意义下的多项式乘法;

    第三个用于在未知一个不超过 \((r-l)\) 次的多项式具体形式,但已知其在某连续区间 \([l,r]\)\((r-l+1)\) 个点值时,求出任一与之不交的、长度相同的区间的 \((r-l+1)\) 个点的值。复杂度为 \(O(n\log n)\)


    【分析】

    考虑令 \(s=\lfloor\sqrt n\rfloor\) ,则 \(\displaystyle n!=\prod_{i=1}^n i=\prod_{i=0}^{s-1}\prod_{t=1}^s(is+t)\cdot \prod_{i=s^2+1}^n i\)

    若已求出 \(\displaystyle \prod_{i=0}^{s-1}\prod_{t=1}^s(is+t)\),则后面的那个式子直接暴力点乘,复杂度为 \(o(\sqrt n)\) 的。

    有个比较简单的方法是由分治 FFT 或多项式启发式合并求出 \(s\) 次多项式 \(\displaystyle g_s(x)=\prod_{t=1}^s (x+t)\) ;再由多项式多点求值求出 \(g_s(0\cdot s),g_s(1\cdot s),g_s(2\cdot s),\cdots, g_s((s-1)\cdot s)\) 。这些乘起来即为待求项。

    这样一来,复杂度为 \(O(s\log^2 s)+O(s\log^2 s)=O(\sqrt n\log^2 n)\)


    但考虑到最后实际对答案的贡献仅在若干个点处取到。若这些点所处的区间连续,则可能可用多项式连续点值平移求出答案。

    很幸运的是,我们令 \(h_s(x)=g_s(sx)\),则待求的点值变换为了 \(h_s(0),h_s(1),h_s(2),\cdots, h_s(s-1)\) ,是长度为 \(s\) 的连续区间。这就可以用到多项式连续点值平移了。

    现在,我们可以考虑如何使用多项式连续点值平移凑出这些东西了。

    而初始我们也并不知道 \(h_s(x)\) 的任何一个点值,但是我们知道 \(h_1(x)=x+1\) 的所有点值。

    因此,我们考虑如何倍增得到 \(h_s(x)\) 的那些点值。

    当我们已知 \(h_d(0),h_d(1),\cdots,h_d(d)\) 时,若我们想要得到 \(h_{2d}(0),h_{2d}(1),\cdots,h_{2d}(2d)\)

    对于 \(\displaystyle \forall i\leq d\to h_{2d}(i)=g_{2d}(is)=\prod_{t=1}^{2d}(is+t)=g_d(is)\cdot g_d(is+d)\)

    由于 \(g_d(is+d)\) 在模 \(P\) 意义下,与 \(h_d(d\cdot s^{-1}+i)\) 等价,因此 \(h_{2d}(i)=h_d(i)\cdot h_d(i+d\cdot s^{-1})\)

    求解这些点值时,我们只要将 \(h_d(0),h_d(1),\cdots,h_d(d)\) 这些连续点值平移出 \(h_{2d}(0+d\cdot s^{-1}),h_d(1+d\cdot s^{-1}),\cdots, h_d(d+d\cdot s^{-1})\) ,直接两两点乘即可倍增。

    而对于后面的那几个点,如果我们已知 \(h_d(d+1),h_d(d+2),\cdots, h_d(2d)\) ,就也能通过这个方法求解。

    那怎么知道呢?问就是多项式连续点值平移,用 \(h_d(0),h_d(1),\cdots,h_d(d)\) 直接平移出 \(h_d(d+1),h_d(d+2),\cdots,h_d(2d+1)\) 就可以了(这里因为要求区间不交,会额外平移出 \(h_d(2d+1)\))。

    既然这样,那也别客气了,分那么清楚, 第一次先平移出 \(h_d(d+1),h_d(d+2),\cdots, h_d(2d+1)\) ;拼接到原点值后面之后,第二次再直接用 \(h_d(0),h_d(1),\cdots,h_d(2d+1)\) 直接平移出 \(h_d(0+d\cdot s^{-1}),h_d(1+d\cdot s^{-1}),\cdots,h_d(2d+1+d\cdot s^{-1})\) 。点乘则可以得到我们要求的 \(h_{2d}(0),h_{2d}(1),\cdots,h_{2d}(2d)\) 。复杂度为 \(O(d\log d)+O(2d\log 2d)+O(2d)=O(d\log d)\)


    由于我们可以从 \(h_d\) 的状态推出 \(h_{2d}\) 的状态,我们又已知 \(h_1\) 的状态,待求 \(h_s\) 的状态;于是我们去考虑用类似递归快速幂的迭代方法求出结果。

    那还差的步骤就是由 \(h_d\) 的状态推出 \(h_{d+1}\) 的状态。

    这个简单, 由于 \(\displaystyle h_{d+1}(i)=g_{d+1}(is)=\prod_{t=1}^{d+1}(is+t)=g_d(is)\cdot (is+d+1)=h_d(i)\cdot (is+d+1)\) ,我们直接 \(O(d)\) 暴力跑过去就行了。

    但是由 \(h_d\)\(h_{d+1}\) 时,需要额外知道 \(h_{d+1}(d+1)\) 的信息,它需要用到 \(h_d(d+1)\) 的信息。

    但这个无伤大雅,因为 \(h_d\) 推出 \(h_{d+1}\) 的状态一定发生在 \(h_{d\over 2}\) 推出 \(h_d\) 之后。而像上面说的,正好由于区间不交,恰好在 \(h_d\) 推出 \(h_{2d}\) 时,把 \(h_{2d}(2d+1)\) 的给顺便算了。

    那就好办了,如果接着要 \(h_{2d}\)\(h_{2d+1}\) ,就把这个正好算了;如果不要,就顺手丢了。于是这一步的复杂度是严格 \(O(d)\) 的。


    综上,我们由已知的 \(h_1\) 状态,倍增地推出 \(h_s\) 的状态。最后累乘 \(h_s(0),\cdots, h_s(s-1)\) ,再乘上那些尾巴,就可以算出总答案了。

    复杂度 \(\displaystyle T(s)=T({s\over 2})+O(s\log s)+O(s)\) 得到 \(T(s)=O(s\log s)=O(\sqrt n\log n)\) 。(原本还要加上一个 \(o(\sqrt n)\) ,但是作为低阶项舍去了。)


    【核心代码】

    poly c, d;
    inline int get_fac(int n) {
    	int pos=curStk, s=sqrt(n)+1e-6;
    	vir invs=Inv[s];
    	for(int i=s;i>1;i>>=1) Stk[++curStk]=i;//手动压栈
    
    	c.resize(2); c[0]=1; c[1]=s+1;//初始化 h_1 的状态
    	for(int l=Stk[curStk]; curStk>pos; l=Stk[--curStk]) {
    		ValueTrans(c, d, l>>1, (l>>1)+1);//点值平移出额外状态
    		c.resize(2*sz(c));
    		for(int i=0; i<sz(d); ++i) c[sz(d)+i]=d[i];
    		ValueTrans(c, d, sz(c)-1, invs*vir(l>>1));//点值平移出点乘状态
    		for(int i=0; i<sz(c); ++i) c[i]=c[i]*d[i];
    		if(l&1) {
    			for(int i=0; i<=l; ++i) c[i]=c[i]*vir(i*s+l);//h_l 转为 h_{l+1}
    		}
    		else c.resize(l+1);
    	}
    
    	vir res=1;
    	for(int i=0; i<s; ++i) res=res*c[i];//累乘
    	for(int i=s*s+1; i<=n; ++i) res=res*vir(i);//处理尾巴
    	return res;
    }
    

    【拓展】

    你以为这就完了?

    我发现我的 代码 跑得飞快,看到连时限 4s 的十分之一都没跑到,我不禁“恶向胆边生”。用这个板子改了改,测试了一下阶乘求和的结果。

    做法是求出 \(\displaystyle \sum_{i=1}^n i!\)\(\displaystyle \sum_{i=1}^{n-1} i!\),做差算出答案。

    那这个问题怎么用点值平移处理呢?

    我们记 \(\displaystyle S_n=\sum_{i=1}^n i!\) ,考虑它的转移矩阵:

    \(\begin{aligned} \begin{pmatrix} (n+1)! \\S_n \end{pmatrix}&= \begin{pmatrix} n+1&0 \\1&1 \end{pmatrix}\cdot \begin{pmatrix} n! \\S_{n-1} \end{pmatrix} \\&=\prod_{i=1}^n \begin{pmatrix} i+1&0 \\1&1 \end{pmatrix}\cdot \begin{pmatrix} 1 \\0 \end{pmatrix} \end{aligned}\)

    同理,我们考虑 \(\displaystyle g_d(x)=\prod_{i=1}^d\begin{pmatrix} x+i+1&0 \\1&1 \end{pmatrix},h_d(i)=g_d(is)\)

    手玩一下会发现,这个矩阵这个矩阵的乘法也有特点:

    \(\begin{pmatrix} a_1&\ \\b_1&c_1 \end{pmatrix}\cdot\begin{pmatrix} a_2&\ \\b_2&c_2 \end{pmatrix}=\begin{pmatrix} a_1a_2&\ \\b_1a_2+c_1b_2&c_1c_2 \end{pmatrix}\)

    也就是说,这个矩阵的乘积永远只有三个地方有值,并且右下角恒为 \(1\) 。于是我们维护 \(h_d(x)=\begin{pmatrix}mc_{d,0}(x)&\ \\mc_{d,1}(x)&1\end{pmatrix}\)

    由原来维护一个点值变成了维护两个点值 而已

    由于此时为矩阵乘法 \(h_{2d}(i)=h_d(i+d\cdot s^{-1})\cdot h_d(i)\) 形式不会发生变化,但一定要注意矩阵的左右乘不等价。

    发生变化的是 \(h_d\) 转移到 \(h_{d+1}\) 时,此时有:

    \(\begin{aligned} &h_{d+1}(i) \\=&\begin{pmatrix}mc_{d+1,0}(i)&\ \\mc_{d+1,1}(i)&1\end{pmatrix} \\=&\begin{pmatrix}i+d+2&0\\1&1\end{pmatrix}h_d(i) \\=&\begin{pmatrix}i+d+2&0\\1&1\end{pmatrix}\begin{pmatrix}mc_{d,0}(x)&\ \\mc_{d,1}(x)&1\end{pmatrix} \\=&\begin{pmatrix}(i+d+2)mc_{d,0}(x)&\ \\mc_{d,0}+mc_{d,1}(x)&1\end{pmatrix} \end{aligned}\)

    修改此处即可。

    poly mc[2], md[2];
    inline int get_sumfac(int n) {
    	int pos=curStk, s=sqrt(n)+1e-6;
    	vir invs=Inv[s];
    	for(int i=s;i>1;i>>=1) Stk[++curStk]=i;
    
    	mc[0].resize(2); mc[0][0]=2; mc[0][1]=s+2;
    	mc[1].resize(2); mc[1][0]=mc[1][1]=1;
    	for(int l=Stk[curStk]; curStk>pos; l=Stk[--curStk]) {
    		for(int t=0; t<2; ++t){
    			poly &c = mc[t], &d = md[t];
    			ValueTrans(c, d, l>>1, (l>>1)+1);
    			c.resize(2*sz(c));
    			for(int i=0; i<sz(d); ++i) c[sz(d)+i]=d[i];
    			ValueTrans(c, d, sz(c)-1, invs*vir(l>>1));
    		}
    		for(int i=0; i<sz(mc[0]); ++i) {
    			mc[1][i]=mc[0][i]*md[1][i]+mc[1][i];
    			mc[0][i]=mc[0][i]*md[0][i];
    		}
    		if(l&1) {
    			for(int i=0; i<=l; ++i) {
    				mc[1][i]=mc[0][i]+mc[1][i];
    				mc[0][i]=mc[0][i]*vir(i*s+l+1);
    			}
    		}
    		else {
    			mc[0].resize(l+1);
    			mc[1].resize(l+1);
    		}
    	}
    
    	vir res0=1, res1=0;
    	for(int i=0; i<s; ++i){
    		res1=res0*mc[1][i]+res1;
    		res0=res0*mc[0][i];
    	}
    	for(int i=s*s+1; i<=n; ++i){
    		res1=res1+res0;
    		res0=res0*vir(i+1);
    	}
    	return res1;
    }
    

    结果 还是没跑到一半的时间。

  • 相关阅读:
    格式化输出函数(1): Format
    ini 文件操作记要(2): 使用 TMemIniFile
    文本文件读写
    格式化输出函数(3): FormatFloat
    Delphi 中的哈希表(2): TStringHash
    格式化输出函数(2): FormatDateTime
    Delphi 中的哈希表(1): THashedStringList
    调用系统关于对话框
    在Ubuntu上安装Thrift并调通一个例子
    rpm2html/rpmfind
  • 原文地址:https://www.cnblogs.com/JustinRochester/p/15846594.html
Copyright © 2020-2023  润新知