• [洛谷P4245]【模板】任意模数NTT


    题目大意:给你两个多项式$f(x)$和$g(x)$以及一个模数$p(pleqslant10^9)$,求$f*gpmod p$

    题解:任意模数$NTT$,最大的数为$p^2 imesmax{n,m}leqslant10^{23}$,所以一般选$3$个模数即可,求出这三个模数下的答案,然后中国剩余定理即可。

    假设这一位的答案是$x$,三个模数分别为$A,B,C$,那么:

    $$
    xequiv x_1pmod{A}\
    xequiv x_2pmod{B}\
    xequiv x_3pmod{C}
    $$

    先把前两个合并:

    $$
    x_1+k_1A=x_2+k_2B\
    x_1+k_1Aequiv x_2pmod{B}\
    k_1equiv frac{x_2-x_1}Apmod{B}
    $$

    于是求出了$k_1$,也就求出了$xequiv x_1+k_1Apmod{AB}$,记$x_4=x_1+k_1A$

    $$
    x_4+k_4AB=x_3+k_3C\
    x_4+k_4ABequiv x_3pmod{C}\
    k_4equiv dfrac{x_3-x_4}{AB}pmod{C}
    $$

    求出了$k_4$,$xequiv x_4+k_4ABpmod{ABC}$,因为$x<ABC$,所以$x=x_4+k_4AB$

    卡点:$Wn$数组开小,中国剩余定理写错

    C++ Code:

    #include <algorithm>
    #include <cstdio>
    #include <cstring>
    int mod;
    namespace Math {
    	inline int pw(int base, int p, const int mod) {
    		static int res;
    		for (res = 1; p; p >>= 1, base = static_cast<long long> (base) * base % mod) if (p & 1) res = static_cast<long long> (res) * base % mod;
    		return res;
    	}
    	inline int inv(int x, const int mod) { return pw(x, mod - 2, mod); }
    }
    
    const int mod1 = 998244353, mod2 = 1004535809, mod3 = 469762049, G = 3;
    const long long mod_1_2 = static_cast<long long> (mod1) * mod2;
    const int inv_1 = Math::inv(mod1, mod2), inv_2 = Math::inv(mod_1_2 % mod3, mod3);
    struct Int {
    	int A, B, C;
    	explicit inline Int() { }
    	explicit inline Int(int __num) : A(__num), B(__num), C(__num) { }
    	explicit inline Int(int __A, int __B, int __C) : A(__A), B(__B), C(__C) { }
    	static inline Int reduce(const Int &x) {
    		return Int(x.A + (x.A >> 31 & mod1), x.B + (x.B >> 31 & mod2), x.C + (x.C >> 31 & mod3));
    	}
    	inline friend Int operator + (const Int &lhs, const Int &rhs) {
    		return reduce(Int(lhs.A + rhs.A - mod1, lhs.B + rhs.B - mod2, lhs.C + rhs.C - mod3));
    	}
    	inline friend Int operator - (const Int &lhs, const Int &rhs) {
    		return reduce(Int(lhs.A - rhs.A, lhs.B - rhs.B, lhs.C - rhs.C));
    	}
    	inline friend Int operator * (const Int &lhs, const Int &rhs) {
    		return Int(static_cast<long long> (lhs.A) * rhs.A % mod1, static_cast<long long> (lhs.B) * rhs.B % mod2, static_cast<long long> (lhs.C) * rhs.C % mod3);
    	}
    	inline int get() {
    		long long x = static_cast<long long> (B - A + mod2) % mod2 * inv_1 % mod2 * mod1 + A;
    		return (static_cast<long long> (C - x % mod3 + mod3) % mod3 * inv_2 % mod3 * (mod_1_2 % mod) % mod + x) % mod;
    	}
    } ;
    
    #define maxn 131072
    
    namespace Poly {
    #define N (maxn << 1)
    	int lim, s, rev[N];
    	Int Wn[N | 1];
    	inline void init(int n) {
    		s = -1, lim = 1; while (lim < n) lim <<= 1, ++s;
    		for (register int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
    		const Int t(Math::pw(G, (mod1 - 1) / lim, mod1), Math::pw(G, (mod2 - 1) / lim, mod2), Math::pw(G, (mod3 - 1) / lim, mod3));
    		*Wn = Int(1); for (register Int *i = Wn; i != Wn + lim; ++i) *(i + 1) = *i * t;
    	}
    	inline void NTT(Int *A, const int op = 1) {
    		for (register int i = 1; i < lim; ++i) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
    		for (register int mid = 1; mid < lim; mid <<= 1) {
    			const int t = lim / mid >> 1;
    			for (register int i = 0; i < lim; i += mid << 1) {
    				for (register int j = 0; j < mid; ++j) {
    					const Int W = op ? Wn[t * j] : Wn[lim - t * j];
    					const Int X = A[i + j], Y = A[i + j + mid] * W;
    					A[i + j] = X + Y, A[i + j + mid] = X - Y;
    				}
    			}
    		}
    		if (!op) {
    			const Int ilim(Math::inv(lim, mod1), Math::inv(lim, mod2), Math::inv(lim, mod3));
    			for (register Int *i = A; i != A + lim; ++i) *i = (*i) * ilim;
    		}
    	}
    #undef N
    }
    
    int n, m;
    Int A[maxn << 1], B[maxn << 1];
    int main() {
    	scanf("%d%d%d", &n, &m, &mod); ++n, ++m;
    	for (int i = 0, x; i < n; ++i) scanf("%d", &x), A[i] = Int(x % mod);
    	for (int i = 0, x; i < m; ++i) scanf("%d", &x), B[i] = Int(x % mod);
    	Poly::init(n + m);
    	Poly::NTT(A), Poly::NTT(B);
    	for (int i = 0; i < Poly::lim; ++i) A[i] = A[i] * B[i];
    	Poly::NTT(A, 0);
    	for (int i = 0; i < n + m - 1; ++i) {
    		printf("%d", A[i].get());
    		putchar(i == n + m - 2 ? '
    ' : ' ');
    	}
    	return 0;
    }
    

      

  • 相关阅读:
    java基础知识回顾之final
    基础知识《十四》Java异常的栈轨迹fillInStackTrace和printStackTrace的用法
    基础知识《六》---Java集合类: Set、List、Map、Queue使用场景梳理
    基础知识《五》---Java多线程的常见陷阱
    基础知识《四》---Java多线程学习总结
    《转》如何选择合适的服务器托管商
    基础知识《三》java修饰符
    基础知识《零》---Java程序运行机制及运行过程
    应用 JD-Eclipse 插件实现 RFT 中 .class 文件的反向编译
    DOS命令符基本操作
  • 原文地址:https://www.cnblogs.com/Memory-of-winter/p/10223844.html
Copyright © 2020-2023  润新知