题目描述
给出两个 $n$ 位10进制数x和y,求x*y(详见 洛谷P1919)
分析
假设已经学会了FFT/NTT。
高精度乘法只是多项式乘法的特殊情况,相当于$x=10$ 时。
例如n=3,求123*111
$$123 = x^2 + 2x + 3$$
$$111 = x^2 + x +1$$
$$egin{aligned}123 * 111 &= (x^2 + 2x + 3)(x^2 + x +1)\ &= x^4 + 3x^3 + 6x^2 + 5x + 3\ &= 13653end{aligned}$$
代码:
#include<bits/stdc++.h> #define rg register using namespace std; typedef long long ll; const int mod=998244353,g=3; const int maxn = 6e4 + 10; inline int qpow(int x,int k) { int ans=1; while(k) { if(k&1) ans=(ll)ans*x%mod; x=(ll)x*x%mod,k>>=1; } return ans; } inline int module(int x,int y) { x+=y; if(x>=mod) x-=mod; return x; } int rev[4*maxn]; inline void NTT(int*t,int lim,int type) { for(rg int i=0;i<lim;++i) if(i<rev[i]) swap(t[i],t[rev[i]]); for(rg int i=1;i<lim;i<<=1) { int gn=qpow(g,(mod-1)/(i<<1)); if(type==-1) gn=qpow(gn,mod-2); for(rg int j=0;j<lim;j+=(i<<1)) { int gi=1; for(rg int k=0;k<i;++k,gi=(ll)gi*gn%mod) { int x=t[j+k],y=(ll)gi*t[j+i+k]%mod; t[j+k]=module(x,y); t[j+i+k]=module(x,mod-y); } } } if(type==-1) { int inv=qpow(lim,mod-2); for(rg int i=0;i<lim;++i) t[i]=(ll)t[i]*inv%mod; } } int X[4*maxn],Y[4*maxn]; inline void mul(int*x, int*y, int n, int m) { memset(X,0,sizeof(X)); memset(Y,0,sizeof(Y)); int lim = 1, L = 0; //L=0必须写,局部变量默认值很可能不是0 while(lim <= n + m) lim <<= 1, L++; //lim为大于(n+m)的2的幂,所以最多需要4倍空间 for(int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1)); for(rg int i=0;i<lim;++i) X[i]=x[i],Y[i]=y[i]; NTT(X,lim,1); NTT(Y,lim,1); for(rg int i=0;i<lim;++i) X[i]=(ll)X[i]*Y[i]%mod; NTT(X,lim,-1); for(rg int i=0;i<lim;++i) x[i]=X[i]; } int n; int a[4*maxn], b[4*maxn]; char s[maxn]; int main() { scanf("%d", &n); scanf("%s", s); for(int i = 0;i < n;i++) a[i] = s[n-1-i] - '0'; scanf("%s", s); for(int i = 0;i < n;i++) b[i] = s[n-1-i] - '0'; mul(a, b, n, n); // for(int i = 0;i < 2*n;i++) printf("%d ", a[i]); // printf(" "); int tmp = 0; //进位 for(int i = 0;i < 2*n;i++) // { a[i] = a[i] + tmp; tmp = a[i] / 10; a[i] = a[i] % 10; } // for(int i = 0;i < 2*n;i++) printf("%d ", a[i]); // printf(" "); bool flag = true; for(int i = 2*n;i >= 0;i--) //逆序输出,去掉前导零 { if(flag && a[i] == 0) continue; printf("%d", a[i]); flag = false; } return 0; }