• 【知识总结】多项式全家桶(一)(NTT、加减乘除和求逆)


    我这种数学一窍不通的菜鸡终于开始学多项式全家桶了……

    必须要会的前置技能:FFT(不会?戳我:【知识总结】快速傅里叶变换(FFT)

    以下无特殊说明的情况下,多项式的长度指多项式最高次项的次数加(1)

    一、NTT

    跟FFT功能差不多,只是把复数域变成了模域(计算复数系数多项式相乘变成计算在模意义下整数系数多项式相乘)。你看FFT里的单位圆是循环的,模一个质数也是循环的嘛qwq。(n)次单位根(w_n)怎么搞?看这里:【BZOJ3328】PYXFIB(数学)(内含相关证明。只看与原根和单位根相关的内容即可。)

    注意裸的NTT要求模数(p)存在原根并且(p-1)(2)的若干次幂的倍数(这个幂要大于多项式次数(n))。于是通常就会用著名的NTT模数:(998244353=2^{23} imes 7 imes 17+1)

    节约篇幅,代码先不放了。后面所有代码里都有NTT模板……

    二、多项式求逆

    对于(n)次多项式(A),如果有多项式(B)满足(ABequiv 1 mod x^{n+1}),则称(B)(A)在模(x^{n+1})意义下的逆元(和整数逆元差不多)。通常采用倍增的方法求逆元。通常都会规定多项式系数在模(p)的意义下。

    首先,(A)在模(x)的意义下就只有一个常数项,所以此时的逆元(B)也只有一个常数项,就是(A)的常数项模(p)的逆元。

    如果我们知道(B_0)(A)在模(x^{lceilfrac{n}{2} ceil})意义下的逆元,现在要求(B)(A)在模(x^n)意义下的逆元。根据题设,显然有:

    [AB=1mod x^n ]

    很明显,(AB)(1)(n-1)次项系数全是(0),所以模一个(x)的低于(n)次幂也一定是(1)。所以

    [AB_0=AB=1mod x^{lceilfrac{n}{2} ceil} ]

    那么

    [B-B_0=0mod x^{lceilfrac{n}{2} ceil} ]

    两边和模数同时平方:

    [B^2+B_0^2-2BB_0=0mod x^n ]

    两边同时乘(A),得到(别忘了(AB=1mod x^n)):

    [B+AB_0^2-2B_0=0mod x^n ]

    然后移项,得到:

    [B=2B_0-AB_0^2mod x^n ]

    照着这个式子递归算就行了。

    由于后面带余除法的代码包含求逆,所以代码同样略去……

    三、加减乘除

    加减法:直接每项对应相加减。

    乘法:这就是NTT的目的啊喂!

    除法:如果不是带余除法直接乘逆元。下面着重介绍带余除法。

    已知(n-1)次多项式(F)(m-1)次多项式(G),求(n-m)次多项式(Q)和多项式(R)(R)的次数小于(m-1)),满足:

    [F(x)=Q(x)G(x)+R(x) mod x^n ]

    很明显,主要的难点在于式子里有个叫做(R)的嘴子(兔崽子Tzz)。如果能把它搞掉该多好……

    注意到(R)的次数小于(m-1),那么我们把它翻转,末尾补(0),是不是就可以把它模成(0)了?定义(mathrm{Tzz}_{A,n})表示把(A)视作一个长为(n)的多项式(高次项补(0))后翻转的结果。即(mathrm{Tzz}_{A,n}(x)=x^{n-1}A(frac{1}{x})=sumlimits_{i=0}^{n-1}a_ix^{n-i-1})

    (F=QG+R)的每个多项式都代入同一个数,这个多项式也一定是成立的。所以:

    [F(frac{1}{x})=Q(frac{1}{x})G(frac{1}{x})+R(frac{1}{x}) ]

    两边同乘(x^{n-1}),得到:

    [x^{n-1}F(frac{1}{x})=x^{n-m}Q(frac{1}{x})cdot x^{m-1}G(frac{1}{x})+x^{n-1}R(frac{1}{x}) ]

    [mathrm{Tzz}_{F,n}=mathrm{Tzz}_{Q,n-m+1}mathrm{Tzz}_{G,m}+mathrm{Tzz}_{R,n} ]

    现在(mathrm{Tzz}_{R,n})的最高次项是(n-1),但是从常数项到(n-m)次项全是(0)(因为(R)的长度最多就是(m-1))。所以现在如果模(n-m+1),那么(mathrm{Tzz}_{R,n})就是(0)了,而(mathrm{Tzz}_{Q,n-m+1})因为最高次是(n-m)所以不会受到影响。

    于是用(mathrm{Tzz}_{F,n})乘上(mathrm{Tzz}_{G,m})的逆元就是(mathrm{Tzz}_{Q,n-m+1}),翻回去就能得到(Q)

    最后把(Q)代进原式,乘一乘减一减就能算出(R)

    所以这样为什么是对的?(以下“低次项”指翻转后的(n-m)项,“高次项”指翻转后的(m)项)首先在模(x^{n-m+1})意义下肯定能保证低次项是对的(即(mathrm{Tzz}{F,n})(mathrm{Tzz}_{G,m}mathrm{Tzz}_{Q,n-m+1})的前(n-m)项相等)。至于高次项,反正有(mathrm{Tzz}_{R,n})来补锅,所以即使不对也没关系。

    完结撒花。

    下一篇:【知识总结】多项式全家桶(二)(ln和exp)

    代码:洛谷4512

    注意NTT的数组一定要保证多余的元素全部是(0)

    代码开头的#undef是防机惨护身符。

    (我脑子有病啊求原根全是手写的……

    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #include <cctype>
    #undef i
    #undef j
    #undef k
    #undef min
    #undef max
    #undef swap
    #undef sort
    #undef for
    #undef while
    #undef if
    #undef true
    #undef false
    #undef printf
    #undef scanf
    #undef getchar
    #undef putchar
    #define _ 0
    using namespace std;
    
    namespace zyt
    {
    	template<typename T>
    	inline bool read(T &x)
    	{
    		char c;
    		bool f = false;
    		x = 0;
    		do
    			c = getchar();
    		while (c != EOF && c != '-' && !isdigit(c));
    		if (c == EOF)
    			return false;
    		if (c == '-')
    			f = true, c = getchar();
    		do
    			x = x * 10 + c - '0', c = getchar();
    		while (isdigit(c));
    		if (f)
    			x = -x;
    		return true;
    	}
    	template<typename T>
    	inline void write(T x)
    	{
    		static char buf[20];
    		char *pos = buf;
    		if (x < 0)
    			putchar('-'), x = -x;
    		do
    			*pos++ = x % 10 + '0';
    		while (x /= 10);
    		while (pos > buf)
    			putchar(*--pos);
    	}
    	typedef long long ll;
    	const int N = 1e5 + 10, LEN = (N << 2), p = 998244353;
    	namespace Polynomial
    	{
    		inline int power(int a, int b)
    		{
    			a %= p, b %= p - 1;
    			int ans = 1;
    			while (b)
    			{
    				if (b & 1)
    					ans = (ll)ans * a % p;
    				a = (ll)a * a % p;
    				b >>= 1;
    			}
    			return ans;
    		}
    		inline int inv(const int a)
    		{
    			return power(a, p - 2);
    		}
    		namespace Primitive_Root
    		{
    			pair<int, int> prime[20];
    			int cnt;
    			void get_prime(int n)
    			{
    				cnt = 0;
    				for (int i = 2; i * i <= n; i++)
    				{
    					if (n % i == 0)
    						prime[cnt++] = make_pair(i, 0);
    					while (n % i == 0)
    						++prime[cnt - 1].second, n /= i;
    				}
    			}
    			int get_g(const int n)
    			{
    				get_prime(n - 1);
    				for (int i = 2; i < n; i++)
    				{
    					bool flag = true;
    					for (int j = 0; j < cnt && flag; j++)
    						flag &= (power(i, (n - 1) / prime[j].first) != 1);
    					if (flag)
    						return i;
    				}
    				return -1;
    			}
    		}
    		int omega[LEN], winv[LEN], rev[LEN];
    		void init(const int n, const int lg2)
    		{
    			static int g = 0;
    			if (!g)
    				g = Primitive_Root::get_g(p);
    			int w = power(g, (p - 1) / n), wi = inv(w);
    			omega[0] = winv[0] = 1;
    			for (int i = 1; i < n; i++)
    			{
    				omega[i] = (ll)omega[i - 1] * w % p;
    				winv[i] = (ll)winv[i - 1] * wi % p;
    			}
    			for (int i = 0; i < n; i++)
    				rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
    		}
    		void ntt(int *a, const int *w, const int n)
    		{
    			for (int i = 0; i < n; i++)
    				if (i < rev[i])
    					swap(a[i], a[rev[i]]);
    			for (int l = 1; l < n; l <<= 1)
    				for (int i = 0; i < n; i += (l << 1))
    					for (int k = 0; k < l; k++)
    					{
    						int tmp = (a[i + k] - (ll)w[n / (l << 1) * k] * a[i + l + k] % p + p) % p;
    						a[i + k] = (a[i + k] + (ll)w[n / (l << 1) * k] * a[i + l + k] % p) % p;
    						a[i + l + k] = tmp;
    					}
    		}
    		void reverse(int *a, const int n)
    		{
    			static int tmp[LEN];
    			memcpy(tmp, a, sizeof(int[n]));
    			for (int i = 0; i < n; i++)
    				a[i] = tmp[n - i - 1];
    		}
    		inline void plus(const int *a, const int *b, int *c, const int n)
    		{
    			for (int i = 0; i < n; i++)
    				c[i] = (a[i] + b[i]) % p;
    		}
    		inline void minus(const int *a, const int *b, int *c, const int n)
    		{
    			for (int i = 0; i < n; i++)
    				c[i] = (a[i] - b[i] + p) % p;
    		}
    		void _inv(const int *a, int *b, const int n)
    		{
    			if (n == 1)
    				b[0] = inv(a[0]);
    			else
    			{
    				static int tmp[LEN];
    				_inv(a, b, (n + 1) >> 1);
    				int m = 1, lg2 = 0;
    				while (m < (n << 1) - 1)
    					m <<= 1, ++lg2;
    				memcpy(tmp, a, sizeof(int[n]));
    				memset(tmp + n, 0, sizeof(int[m - n]));
    				memset(b + ((n + 1) >> 1), 0, sizeof(int[m - ((n + 1) >> 1)]));
    				init(m, lg2);
    				ntt(tmp, omega, m);
    				ntt(b, omega, m);
    				for (int i = 0; i < m; i++)
    					b[i] = (b[i] * 2LL % p - (ll)tmp[i] * b[i] % p * b[i] % p + p) % p;
    				ntt(b, winv, m);
    				int invm = inv(m);
    				for (int i = 0; i < m; i++)
    					b[i] = (ll)b[i] * invm % p;
    				memset(b + n, 0, sizeof(int[m - n]));
    			}
    		}
    		void inv(const int *a, int *b, const int n)
    		{
    			static int tmp[LEN];
    			memcpy(tmp, a, sizeof(int[n]));
    			_inv(tmp, b, n);
    		}
    		void mul(const int *a, const int *b, int *c, const int n)
    		{
    			int m = 1, lg2 = 0;
    			while (m < (n << 1))
    				m <<= 1, ++lg2;
    			static int x[LEN], y[LEN];
    			memcpy(x, a, sizeof(int[n]));
    			memset(x + n, 0, sizeof(int[m - n]));
    			memcpy(y, b, sizeof(int[n]));
    			memset(y + n, 0, sizeof(int[m - n]));
    			init(m, lg2);
    			ntt(x, omega, m);
    			ntt(y, omega, m);
    			for (int i = 0; i < m; i++)
    				x[i] = (ll)x[i] * y[i] % p;
    			ntt(x, winv, m);
    			int invm = inv(m);
    			for (int i = 0; i < m; i++)
    				x[i] = (ll)x[i] * invm % p;
    			memcpy(c, x, sizeof(int[n]));
    		}
    		void div(const int *_F, const int *_G, int *_Q, int *_R, const int n, const int m)
    		{
    			static int F[LEN], G[LEN], invG[LEN], Q[LEN], R[LEN];
    			memcpy(F, _F, sizeof(int[n]));
    			memcpy(G, _G, sizeof(int[m]));
    			reverse(F, n), reverse(G, m);
    			if (m < n - m + 1)
    				memset(G + m, 0, sizeof(int[n - m + 1 - m]));
    			inv(G, invG, n - m + 1);
    			mul(F, invG, Q, n - m + 1);
    			reverse(F, n), reverse(G, m), reverse(Q, n - m + 1);
    			mul(G, Q, G, n);
    			minus(F, G, R, n);
    			memcpy(_Q, Q, sizeof(int[n - m + 1]));
    			memcpy(_R, R, sizeof(int[m]));
    		}
    	}
    	int F[LEN], G[LEN], Q[LEN], R[LEN];
    	int work()
    	{
    		int n, m;
    		read(n), read(m);
    		++n, ++m;
    		for (int i = 0; i < n; i++)
    			read(F[i]);
    		for (int i = 0; i < m; i++)
    			read(G[i]);
    		Polynomial::div(F, G, Q, R, n, m);
    		for (int i = 0; i < n - m + 1; i++)
    			write(Q[i]), putchar(' ');
    		putchar('
    ');
    		for (int i = 0; i < m - 1; i++)
    			write(R[i]), putchar(' ');
    		return (0^_^0);
    	}
    }
    int main()
    {
    	return zyt::work();
    }
    
  • 相关阅读:
    NHibernate 入门必看——NHibernate Made Simple
    ASP.NET 的多线程
    asp.net 禁止用户二次登录(转)
    marquee标记用法及在asp.net中的应用(转)
    解决Visual Studio 2005显示中文乱码(zhuan)
    ms sql 触发器( 转)
    Asp.net 页面导航的几种方法与比较
    ASP.NET1.1(VB):DataGrid中"加入序号列"和"截取定长字符串追加'...
    解决“Internet Explorer 无法打开 Internet站点已终止操作”问题(转)
    ASP.NET 2.0的页面指令集(转)
  • 原文地址:https://www.cnblogs.com/zyt1253679098/p/10226915.html
Copyright © 2020-2023  润新知