• luoguP4705 玩游戏 分治FFT



    \[\begin{aligned} Ans(k) &= \sum_{i = 1}^n \sum_{j = 1}^m (a_i + b_j)^k \\ &= \sum_{i = 1}^n \sum_{j = 1}^m \sum_{t = 0}^k \binom{k}{t} a_i^t b_j^{k - t} \\ &= \sum_{t = 0}^k \binom{k}{t} \left(\sum_{i = 1}^n a_i^t\right) \left(\sum_{j = 1}^m b_j^{k - t}\right) \\ &= k! \sum_{t = 0}^k \frac{\sum_{i = 1}^n a_i^t}{t!} \cdot \frac{\sum_{j = 1}^m b_j^{k - t}}{(k - t)!} \end{aligned}\]

    右边是一个卷积(最终输出的期望为 $\frac{Ans(k)}{nm}$),因此只需对 $t = 0, 1, \dots, T$($T$ 为需要输出的最大 $k$)求出 $f(t) = \sum_{i = 1}^n a_i^t$,对 $b$ 同理。


    考虑 $f$ 的普通生成函数(OGF):

    \[\begin{aligned} F(x) &= \sum_{i = 0}^{\infty} f(i) x^i \\ &= \sum_{i = 0}^{\infty} \sum_{j = 1}^{n} a_j^i x^i \\ &= \sum_{j = 1}^n \sum_{i = 0}^{\infty} (a_j x)^i \\ &= \sum_{j = 1}^n \frac{1}{1 - a_j x} \end{aligned}\]

    那么,现在的问题在于如何求解

    \[\sum_{i = 1}^n \frac{1}{1 - a_i x}\]


    考虑分治 FFT。
    一种很好想的思路是先分别求出 $\sum_{i = l}^{mid} \frac{1}{1 - a_i x}$ 和 $\sum_{i = mid + 1}^{r} \frac{1}{1 - a_i x}$。
    它们一定是形如 $\frac{A}{B}$ 的式子,不妨设左边为 $\frac{A}{B}$,右边为 $\frac{C}{D}$,
    那么合并之后的形式为 $\frac{AD + BC}{BD}$,递归维护分子和分母即可。

    复杂度 $O(n \log^2 n)$。


    可以发现 $$\ln'(PQ) = (\ln P + \ln Q)' = \ln' P + \ln' Q$$
    因此,我们考虑 $$\ln'(1 - a_i x) = \frac{-a_i}{1 - a_i x}$$
    注意 $\ln$ 只对常数项为 $1$ 的多项式有定义(不能带常数因子),而 $1 - a_i x$ 恰好满足这一点,因此我们从这个形式来考虑。

    \[\begin{aligned} G(x) &= \sum_{i = 1}^n \ln'(1 - a_i x) \\ &= \ln' \left(\prod_{i = 1}^n (1 - a_i x)\right) \end{aligned}\]

    可以先用分治 FFT 求出 $\prod_{i = 1}^n (1 - a_i x)$,再对它求 $\ln$ 并求导,即得到 $G$。
    观察两者的展开式:

    \[F(x) = \sum_{i = 1}^n \left(a_i^0 + a_i^1 x + a_i^2 x^2 + a_i^3 x^3 + \cdots\right)\]

    \[G(x) = \sum_{i = 1}^n \frac{-a_i}{1 - a_i x} = -\sum_{i = 1}^n \left(a_i^1 + a_i^2 x + a_i^3 x^2 + \cdots\right)\]

    因此(常数项 $\sum_{i=1}^n a_i^0 = n$),$$-xG(x) + n = F(x)$$

    然后把 $F$ 的第 $t$ 项除以 $t!$,与 $b$ 的对应结果做一次卷积,再乘 $k!$ 并除以 $nm$ 就可以啦。


    #include <bits/stdc++.h>
    using namespace std;
    
    // Loop shorthands used throughout the solution.
    // `register` was removed as a storage-class specifier in C++17 (a hard error on
    // some compilers, an ignored hint on the rest), so `ri` is now a plain `int`.
    #define ri int
    // rep: ascending inclusive loop io = st..ed; drep: descending inclusive loop io = ed..st.
    // Bounds are parenthesized so expression arguments expand safely.
    #define rep(io, st, ed) for(ri io = (st); io <= (ed); io ++)
    #define drep(io, ed, st) for(ri io = (ed); io >= (st); io --)
    
    // Capacity of every global buffer: 500000 plus a little slack.
    constexpr int sid = 500000 + 5;
    // NTT-friendly prime 998244353 = 119 * 2^23 + 1 (the NTT below uses 3 as its root).
    constexpr int mod = 998244353;
    
    #define gc getchar
    // Reads one decimal integer from stdin, skipping any non-digit characters before it.
    // A '-' anywhere in the skipped prefix makes the result negative.
    // Returns 0 if EOF is reached before any digit. Previously the skip loop never
    // terminated at EOF, because EOF does not satisfy the digit test.
    // `c` is an int so EOF is distinguishable from every byte value.
    inline int read() {
    	int p = 0, w = 1;
    	int c = gc();
    	while((c > '9' || c < '0') && c != EOF) { if(c == '-') w = -1; c = gc(); }
    	while(c >= '0' && c <= '9') p = p * 10 + c - '0', c = gc();
    	return p * w;
    }
    
    // Modular addition; both operands are expected to lie in [0, mod).
    inline int Inc(int a, int b) { const int s = a + b; return s < mod ? s : s - mod; }
    // Modular subtraction; both operands are expected to lie in [0, mod).
    inline int Dec(int a, int b) { const int d = a - b; return d >= 0 ? d : d + mod; }
    // Modular multiplication, widened to 64 bits so the product cannot overflow.
    inline int mul(int a, int b) { return static_cast<int>(static_cast<long long>(a) * b % mod); }
    // Binary exponentiation: returns a^k mod mod. k must be non-negative
    // (a negative k never reaches 0 under >>= and would not terminate).
    inline int fp(int a, int k) {
    	int result = 1, base = a;
    	while(k) {
    		if(k & 1) result = mul(result, base);
    		base = mul(base, base);
    		k >>= 1;
    	}
    	return result;
    }
    	
    int rev[sid];          // bit-reversal permutation for the current NTT length
    int fac[sid];          // fac[i] = i! (mod p), filled by Init_Inv
    int inv[sid];          // inv[i] = 1 / i (mod p), filled by Init_Inv
    int ivf[sid];          // ivf[i] = 1 / i! (mod p), filled by Init_Inv
    int a[sid], b[sid];    // input sequences, stored 1-indexed
    int ak[sid], bk[sid];  // ak[k] = (sum_i a_i^k) / k!, likewise bk for b (from calc)
    	
    // Sets n to the smallest power of two that is >= Mn (and at least 1), and lg = log2(n).
    inline void init(int Mn, int &n, int &lg) {
    	for(n = 1, lg = 0; n < Mn; ++ lg) n <<= 1;
    }
    	
    inline void NTT(int *a, int n, int opt) {
    	for(ri i = 0; i < n; i ++) 
    		if(i < rev[i]) swap(a[i], a[rev[i]]);
    	for(ri i = 1; i < n; i <<= 1)
    	for(ri j = 0, g = fp(3, (mod - 1) / (i << 1)); j < n; j += (i << 1))
    	for(ri k = j, G = 1; k < i + j; k ++, G = mul(G, g)) {
    		int x = a[k], y = mul(G, a[i + k]);
    		a[k] = (x + y >= mod) ? x + y - mod : x + y;
    		a[i + k] = (x - y < 0) ? x - y + mod : x - y;
    	}
    	if(opt == -1) {
    		reverse(a + 1, a + n);
    		int ivn = fp(n, mod - 2);
    		for(ri i = 0; i < n; i ++) a[i] = mul(a[i], ivn);
    	}
    }
    	
    int ia[sid], ib[sid];
    inline void Inv(int *a, int *b, int n) {
    	if(n == 1) { b[0] = fp(a[0], mod - 2); return; }
    	Inv(a, b, n >> 1);
    	
    	int N = 1, lg = 0; init(n + n, N, lg);
    	for(ri i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    	for(ri i = 0; i < N; i ++) ia[i] = ib[i] = 0;
    	for(ri i = 0; i < n; i ++) ia[i] = a[i], ib[i] = b[i];
    	
    	NTT(ia, N, 1); NTT(ib, N, 1);
    	for(ri i = 0; i < N; i ++) ia[i] = Dec((ib[i] << 1) % mod, mul(ia[i], mul(ib[i], ib[i])));
    	NTT(ia, N, -1);
    	
    	for(ri i = 0; i < n; i ++) b[i] = ia[i];
    }
    	
    // For every i <= n, fills:
    //   inv[i] = 1 / i, fac[i] = i!, ivf[i] = 1 / i!   (all mod p).
    // inv[0] = 1 is only a placeholder. Index n is written, so n must be < sid.
    inline void Init_Inv(int n) {
    	inv[0] = inv[1] = 1;
    	fac[0] = fac[1] = 1;
    	ivf[0] = ivf[1] = 1;
    	for(int i = 2; i <= n; i ++) {
    		// Linear-time inverse: 1/i = -(mod / i) * 1/(mod % i), with mod % i < i.
    		inv[i] = mul(inv[mod % i], mod - mod / i);
    		fac[i] = mul(fac[i - 1], i);
    		ivf[i] = mul(ivf[i - 1], inv[i]);
    	}
    }
    	
    // Formal derivative: b[i - 1] = i * a[i] for i = 1..n-1 (writes b[0..n-2] only).
    inline void wf(int *a, int *b, int n) {
    	for(int deg = 1; deg < n; deg ++) b[deg - 1] = mul(a[deg], deg);
    }
    // Formal integral: b[i] = a[i - 1] / i for i = 1..n-1. b[0] is left for the caller to set.
    inline void jf(int *a, int *b, int n) {
    	for(int deg = 1; deg < n; deg ++) b[deg] = mul(a[deg - 1], inv[deg]);
    }
    	
    int da[sid], iva[sid];   // scratch: derivative of a, and inverse of a
    // Power-series logarithm: writes b[1..n-1] = integral(a' / a) mod x^n.
    // This equals ln(a / a[0]), i.e. ln(a) when a[0] = 1.
    // n must be a power of two and a[0] must be non-zero. b[0] is not written
    // (the visible callers zero it beforehand).
    // Clobbers rev[], da[], iva[] and Inv's scratch buffers.
    inline void In(int *a, int *b, int n) {
    	const int len2 = n << 1;
    	for(int i = 0; i < len2; i ++) { da[i] = 0; iva[i] = 0; }
    	Inv(a, iva, n);
    	wf(a, da, n);
    
    	int len = 1, bits = 0;
    	init(len2, len, bits);
    	for(int i = 0; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));
    
    	NTT(da, len, 1);
    	NTT(iva, len, 1);
    	for(int i = 0; i < len; i ++) da[i] = mul(da[i], iva[i]);
    	NTT(da, len, -1);
    	jf(da, b, n);
    }
    
    int hb[sid], inb[sid];
    inline void Exp(int *a, int *b, int n) {
    	if(n == 1) { b[0] = 1; return; }
    	Exp(a, b, n >> 1);
    	
    	int N = 1, lg = 0; init(n + n, N, lg);
    	for(ri i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    	
    	for(ri i = 0; i < N; i ++) inb[i] = hb[i] = 0;
    	In(b, inb, n);
    	for(ri i = 0; i < n; i ++) hb[i] = Dec(a[i], inb[i]); hb[0] ++;
    	
    	NTT(inb, N, 1); NTT(hb, N, 1);
    	for(ri i = 0; i < N; i ++) inb[i] = mul(inb[i], hb[i]);
    	NTT(inb, N, -1);
    	
    	for(ri i = 0; i < n; i ++) b[i] = inb[i];
    }
    
    int Ib[sid], F[sid * 2], pa[sid], pb[sid];
    inline void calc(int *a, int *b, int n, int t) {
    	int N = 1, lg = 0;
    	init(max(n, t) + 5, N, lg);
    	for(ri i = 0; i < (N << 1); i ++) F[i] = 0;
    	for(ri i = 0; i < n; i ++) F[2 * i] = 1, F[2 * i + 1] = mod - a[i + 1];
    	for(ri i = n; i < N; i ++) F[2 * i] = 1;
    	for(ri i = 1; i < N; i <<= 1) {
    		for(ri j = 0; j < N; j += (i << 1)) {
    			int M = 1, lg = 0;
    			init((i << 2), M, lg);
    			for(ri k = 0; k < M; k ++) rev[k] = (rev[k >> 1] >> 1) | ((k & 1) << (lg - 1));
    			for(ri k = 0; k < M; k ++) pa[k] = pb[k] = 0;
    			for(ri k = 0; k < (i << 1); k ++) 
    				pa[k] = F[(j << 1) + k], pb[k] = F[(j << 1) + (i << 1) + k];
    			NTT(pa, M, 1); NTT(pb, M, 1);
    			for(ri k = 0; k < M; k ++) pa[k] = mul(pa[k], pb[k]);
    			NTT(pa, M, -1);
    			for(ri k = 0; k < (i << 2); k ++) F[(j << 1) + k] = pa[k];
    		}
    	}
    	for(ri i = 0; i < N; i ++) Ib[i] = 0;
    	In(F, Ib, N); wf(Ib, F, N);
    	
    	b[0] = n;
    	for(ri i = 1; i <= t; i ++) b[i] = mul(mod - F[i - 1], ivf[i]);
    }
    	
    inline void solve(int n, int m, int t) {
    	int N = 1, lg = 0;
    	init(t + t + 5, N, lg);
    	for(ri i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    	
    	NTT(ak, N, 1); NTT(bk, N, 1);
    	for(ri i = 0; i < N; i ++) ak[i] = mul(ak[i], bk[i]);
    	NTT(ak, N, -1);
    		
    	int ivnm = fp(mul(n, m), mod - 2);
    	for(ri i = 1; i <= t; i ++) 
    		printf("%d
    ", mul(mul(ak[i], fac[i]), ivnm));
    }
    	
    int main() {
    	// Input: n m, then a_1..a_n, then b_1..b_m, then t.
    	const int n = read();
    	const int m = read();
    	for(int i = 1; i <= n; i ++) a[i] = read();
    	for(int j = 1; j <= m; j ++) b[j] = read();
    	Init_Inv(500000);
    	const int t = read();
    	// ak/bk: power sums of a and b divided by factorials, then their convolution.
    	calc(a, ak, n, t);
    	calc(b, bk, m, t);
    	solve(n, m, t);
    	return 0;
    }
    

    请无视中间的 `Exp` 函数(本题解并未调用它)。

  • 相关阅读:
    Apache2的安装
    JVM(9) 程序编译及代码优化
    Java基础(43)Queue队列
    Java基础(42)AbstractSet类
    OptimalSolution(10)--日常
    OptimalSolution(9)--其他问题(1)
    OptimalSolution(9)--其他问题(2)
    OptimalSolution(8)--位运算
    OptimalSolution(7)--大数据和空间限制
    golang教程汇总
  • 原文地址:https://www.cnblogs.com/reverymoon/p/10187388.html
Copyright © 2020-2023  润新知