变一下题目的式子,变成 $A[i]+A[k]=2A[j],i<j,k>j$
发现 $A[i]$ 的值域不大,考虑移动指针 $pos$ 并维护 $cntl[],cntr[]$ 分别表示 $pos$ 左右两边各种值的数的数量
设 $ans[i]$ 表示当前 $pos$ 左右两边各取一个数,相加为 $i$ 的方案数,那么对于每一个 $j=pos$,贡献即为 $ans[2A[pos]]$
同时 $ans[i]=sum_{j=0}^{i}cntl[j]cntr[i-j]$,其实就是一个卷积的形式,可以用 $FFT$ 优化,但是 $ans$ 会随着 $pos$ 改变一起改变
如果用 $FFT$ 未免大材小用了,对于每一个位置 $pos$ 我们只想知道 $ans[2A[pos]]$,其他的都无关,$FFT$ 好像只会白白增加复杂度
考虑分块,设块大小为 $T$
对于块内的情况可以直接暴力算,复杂度 $nT$
对于 $i$ 在块左边,$j$ 在块内,$k$ 在块右边的情况,直接搞 $FFT$
这样一次 $FFT$ 可以解决 块大小的答案,设值域为 $S$,那么这样复杂度就是 $O(n/TSlog_S)$
发现 $T$ 取 $sqrt(n)$ 时最优,但是因为 $FFT$ 常数大,所以可以适当增大 $T$ 来减少 $FFT$ 的次数
具体看代码就行,不难看懂
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> using namespace std; typedef long long ll; typedef double db; inline int read() { int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f; } const int N=1e5+7; const db pi=acos(-1.0); struct CP { db x,y; CP (db xx=0,db yy=0) { x=xx,y=yy; } inline CP operator + (const CP &tmp) const { return CP(x+tmp.x,y+tmp.y); } inline CP operator - (const CP &tmp) const { return CP(x-tmp.x,y-tmp.y); } inline CP operator * (const CP &tmp) const { return CP(x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x); } }A[N],B[N]; int n,p[N],T,d[N],cntl[N],cntr[N]; ll ans; void FFT(CP *A,int len,int type) { for(int i=0;i<len;i++) if(i<p[i]) swap(A[i],A[p[i]]); for(int mid=1;mid<len;mid<<=1) { CP wn(cos(pi/mid),type*sin(pi/mid)); for(int R=mid<<1,j=0;j<len;j+=R) { CP w(1,0); for(int k=0;k<mid;k++,w=w*wn) { CP x=A[j+k],y=w*A[j+mid+k]; A[j+k]=x+y; A[j+mid+k]=x-y; } } } } int main() { n=read(); T=min(n,6*(int)sqrt(n)); for(int i=1;i<=n;i++) d[i]=read(),cntr[d[i]]++; for(int i=1;i<n;i+=T) { int r=min(i+T-1,n); for(int j=i;j<=r;j++) cntr[d[j]]--; for(int j=i;j<=r;j++) { for(int k=j+1;k<=r;k++) { int t=(d[j]<<1)-d[k]; if(t>=0) ans+=cntl[t]; t=(d[k]<<1)-d[j]; if(t>=0) ans+=cntr[t]; } cntl[d[j]]++; } } for(int i=1;i<n;i+=T) { int mx=0,r=min(i+T-1,n); for(int j=1;j<i;j++) A[d[j]].x++,mx=max(mx,d[j]); for(int j=r+1;j<=n;j++) B[d[j]].x++,mx=max(mx,d[j]); int len=1,tot=0; while(len<=mx*2) len<<=1,tot++; for(int j=0;j<len;j++) p[j]=(p[j>>1]>>1) | ((j&1)<<(tot-1)); FFT(A,len,1); FFT(B,len,1); for(int j=0;j<=len;j++) A[j]=A[j]*B[j]; FFT(A,len,-1); for(int j=i;j<=r;j++) ans+=ll(A[d[j]<<1].x/len+0.5); for(int j=0;j<=len;j++) A[j]=B[j]=CP(0,0); } printf("%lld ",ans); return 0; }