题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3160
求出关于一个位置有多少对对称字母,如果 i 位置有 f[i] 对,对答案的贡献是 2^f[i] - 1;
然后减去连续的,用 manachar 求出回文长度,每个位置作为边界都是一种不合法情况;
求对称,首先把字符串中间穿插字符 '$',于是字符串的长度变成2倍;
考虑一对字母 s[x],s[y],如果 s[x] = s[y],其对称中心是 (x+y)/2;
放在加入字符后的字符串中,对称中心就是 x+y;
所以可以看出卷积了:f[i] = ∑(0<=j<=i) (s[j]==s[i-j]),其中 i 视为新字符串中的位置,j 和 i-j 视为原字符串中的位置;
注意卷积和 manachar 算的个数都要包括自己成对,否则判断挺麻烦...
这里卷积的两个多项式其实是一样的,所以只要用 FFT 算出一个,然后自己乘起来即可;
做下一步的时候注意清空,别忘了清空 n~lim 部分的值;
处理 bin 的边界是 n 而非 n-1,因为最多可能有 n 对。
(学习了 manachar 的简洁写法)
代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> using namespace std; typedef double db; int const xn=(1<<19),mod=1e9+7; db const Pi=acos(-1.0); int n,rev[xn],lim=1,l,len[xn],bin[xn],c[xn]; char ch[xn]; struct com{db x,y;}a[xn],b[xn],aa[xn]; com operator + (com a,com b){return (com){a.x+b.x,a.y+b.y};} com operator - (com a,com b){return (com){a.x-b.x,a.y-b.y};} com operator * (com a,com b){return (com){a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y};} int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;} void fft(com *a,int tp) { for(int i=0;i<lim;i++) if(i<rev[i])swap(a[i],a[rev[i]]); for(int mid=1;mid<lim;mid<<=1) { com wn=(com){cos(Pi/mid),tp*sin(Pi/mid)}; for(int j=0,len=(mid<<1);j<lim;j+=len) { com w=(com){1,0}; for(int k=0;k<mid;k++,w=w*wn) { com x=a[j+k],y=w*a[j+mid+k]; a[j+k]=x+y; a[j+mid+k]=x-y; } } } } void solve() { for(int i=0;i<n;i++)a[i].x=(ch[i]=='a'); fft(a,1); for(int i=0;i<lim;i++)b[i]=a[i]*a[i]; for(int i=0;i<n;i++)a[i].x=(ch[i]=='b'),a[i].y=0;//y=0 for(int i=n;i<lim;i++)a[i].x=0,a[i].y=0;//!! fft(a,1); for(int i=0;i<lim;i++)b[i]=b[i]+a[i]*a[i]; fft(b,-1); for(int i=0;i<n+n;i++)c[i]=(c[i]+(int)(b[i].x/lim+0.5))%mod; } char s[xn]; int manachar()//+i self { int mx=0,id=0,ret=0; s[0]='$'; for(int i=1;i<=n+n;i++) if(i%2==0)s[i]='$'; else s[i]=ch[i>>1]; for(int i=1;i<=n+n;i++) { if(i<mx)len[i]=min(mx-i,len[id*2-i]); while(i-len[i]>=0&&i+len[i]<=n+n&&s[i-len[i]]==s[i+len[i]])len[i]++; if(i+len[i]>mx)mx=i+len[i],id=i; ret=upt(ret+len[i]/2); } return ret; } int main() { scanf("%s",ch); n=strlen(ch); while(lim<=n+n)lim<<=1,l++;// for(int i=0;i<lim;i++) rev[i]=((rev[i>>1]>>1)|((i&1)<<(l-1))); bin[0]=1; for(int i=1;i<=n;i++)bin[i]=upt(bin[i-1]+bin[i-1]); solve(); int ans=0; for(int i=0;i<n+n;i++)ans=upt(ans+bin[(c[i]+1)>>1]-1);//+1 -1 printf("%d ",upt(ans-manachar())); return 0; }