大致题意:
给你N个整数和M个整数,问这M个数中,有几个数可以表达成那N个整数中一个或者两个整数的和。
分析:
算是半个裸的FFT。FFT可以用来在nlongn时间内求高精度乘法,我们先模拟一下乘法。
A4A3A2A1A0*B4B3B2B1B0 Ai,Bj表示位数,结果保存在Ck中
4 3 2 1 0(下标)
A4 A3 A2 A1 A0
B4 B3 B2 B1 B0
先不考虑进位
那么C0=A0*B0
C1=A0*B1+A1*B0
Ck=sum(Ai*Bj) (i+j=k)
我们现在看题目是求两个数或者一个数的和是否能在M中匹配到。与上式对比,发现了共同点,i,j的和的下标对应的值即Ck,也就是说我们可以把出现的数字当做位置,出现的位置赋值为1,没有为0,第0位为1,因为可以加0嘛。
样例如下
5 4 3 2 1 0(下标)
1 0 1 0 1 1
X 1 0 1 0 1 1
= 1 0 2 0 3 2 2 2 1 2 1
A 9 8 7 6 5 4 3 2 1 0(下标)
结果不需要再进位,只要k下标对应的值大于0就符合条件。
#include <bits/stdc++.h> using namespace std; const double PI = acos(-1.0); struct Complex { double x, y; Complex(double _x = 0.0, double _y = 0.0) { x = _x; y = _y; } Complex operator - (const Complex &b)const { return Complex(x-b.x, y-b.y); } Complex operator + (const Complex &b)const { return Complex(x+b.x, y+b.y); } Complex operator * (const Complex &b)const { return Complex(x*b.x-y*b.y, x*b.y+y*b.x); } }; void change(Complex y[], int len) { int i, j, k; for(i = 1, j = len/2; i < len-1; i++) { if (i < j) swap(y[i], y[j]); k = len/2; while(j >= k) { j -= k; k /= 2; } if (j < k) j += k; } } void fft(Complex y[], int len, int on) { change(y, len); for(int h = 2; h <= len; h <<= 1) { Complex wn(cos(-on*2*PI/h), sin(-on*2*PI/h)); for(int j = 0; j < len; j += h) { Complex w(1, 0); for(int k = j; k < j+h/2; k++) { Complex u = y[k]; Complex t = w*y[k+h/2]; y[k] = u+t; y[k+h/2] = u-t; w = w*wn; } } } if (on == -1) for(int i = 0; i < len; i++) y[i].x /= len; } const int maxn=4*200010; Complex x1[maxn],x2[maxn]; int a[maxn/4]; int sum[maxn]; int main() { int N,M; while(~scanf("%d",&N)) { memset(a,0,sizeof(a)); int len1=0; for(int i=0; i<N; i++) { int tmp; scanf("%d",&tmp); a[tmp]=1; len1=max(len1,tmp); } len1++; int len=1; while(len < len1*2) len<<=1; a[0]=1; for(int i=0; i<len1; i++) x1[i]=Complex(a[i],0); for(int i=0; i<len1; i++) x2[i]=Complex(a[i],0); for(int i=len1; i<len; i++) x2[i]=Complex(0,0); fft(x1,len,1); fft(x2,len,1); for(int i=0; i<len; i++) x1[i]=x1[i]*x2[i]; fft(x1,len,-1); for(int i=0; i<len; i++) sum[i]=(int)(x1[i].x+0.5); scanf("%d",&M); len=2*len1-1; int cnt=0; for(int i=0;i<M;i++) { int t; scanf("%d",&t); if(sum[t]>0) cnt++; } printf("%d ",cnt); } return 0; }