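题解：
我认为大爷讲得最好。简要思路：对每个对称中心 $k$，设 $m_k$ 为满足 $i+j=k$、$i<j$ 且 $s_i=s_j$ 的位置对数（$k$ 为偶数时中心字符本身再计入一次），它可以用 FFT 对每种字符的指示序列做自卷积求出；该中心贡献 $2^{m_k}-1$ 个对称子序列。全部求和后，再用 Manacher 算法减去连续回文子串的个数，即得答案（对 $10^9+7$ 取模）。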
代码:
// Reads one whitespace-delimited string from stdin and prints the number of
// symmetric (palindromic) subsequences that are NOT contiguous substrings,
// modulo 1e9+7.
//
// Approach:
//   * For a center k (k = i + j over positions i <= j), let m[k] be the number
//     of mirrored position pairs {i, j}, i < j, with str[i] == str[j], plus one
//     for the middle character when k is even. Any non-empty subset of those
//     gives a symmetric subsequence, so center k contributes 2^m[k] - 1.
//     m[k] is obtained from FFT self-convolutions of the indicator vector of
//     each distinct character.
//   * Contiguous palindromic substrings are counted with Manacher's algorithm
//     and subtracted.
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

using namespace std;

using ll = long long;

constexpr ll MOD = 1000000007;
const double Pi = acos(-1.0);

// Returns x^y mod MOD by binary exponentiation (y >= 0).
ll fastpow(ll x, ll y) {
    ll ret = 1;
    x %= MOD;  // keeps x * x below 2^63 even if x >= MOD is passed
    while (y > 0) {
        if (y & 1) ret = ret * x % MOD;
        x = x * x % MOD;
        y >>= 1;
    }
    return ret;
}

// Minimal complex number for the FFT; hand-rolled so that multiplication is a
// plain 4-multiply / 2-add expression.
struct cp {
    double x = 0, y = 0;
    cp() = default;
    cp(double re, double im) : x(re), y(im) {}
};
cp operator+(const cp& a, const cp& b) { return cp(a.x + b.x, a.y + b.y); }
cp operator-(const cp& a, const cp& b) { return cp(a.x - b.x, a.y - b.y); }
cp operator*(const cp& a, const cp& b) {
    return cp(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}

// In-place iterative radix-2 FFT over all of `a`.
//   rev : bit-reversal permutation of [0, a.size()); a.size() is a power of two.
//   k   : +1 for the forward transform, -1 for the inverse
//         (the inverse is NOT divided by the length; callers do that).
void fft(vector<cp>& a, const vector<int>& rev, int k) {
    const int size = static_cast<int>(a.size());
    for (int i = 0; i < size; i++)
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int half = 1; half < size; half <<= 1) {
        // Root of unity for butterfly blocks of length 2 * half.
        const cp w0(cos(Pi / half), k * sin(Pi / half));
        for (int j = 0; j < size; j += half << 1) {
            cp w(1, 0);
            for (int o = 0; o < half; o++, w = w * w0) {
                const cp u = a[j + o];
                const cp v = a[j + o + half] * w;
                a[j + o] = u + v;
                a[j + o + half] = u - v;
            }
        }
    }
}

// Adds to pairs[k] (for every k < rev.size()) the number of ordered index
// pairs (i, j), i + j == k, with str[i] == str[j] == ch; i == j is included.
// rev.size() must be a power of two >= 2 * str.size() so nothing wraps around.
void addSelfConvolution(const string& str, char ch, const vector<int>& rev,
                        vector<ll>& pairs) {
    const int size = static_cast<int>(rev.size());
    vector<cp> f(size);
    for (size_t i = 0; i < str.size(); i++)
        if (str[i] == ch) f[i].x = 1;
    fft(f, rev, 1);
    for (cp& z : f) z = z * z;
    fft(f, rev, -1);
    for (int k = 0; k < size; k++)
        pairs[k] += static_cast<ll>(f[k].x / size + 0.5);  // round to nearest
}

// Number of non-empty contiguous palindromic substrings of `str` (Manacher).
// Works on "#c0#c1#...#" and uses explicit bounds checks instead of sentinel
// characters, so any input character is handled.
ll countPalindromes(const string& str) {
    string t = "#";
    for (char ch : str) {
        t += ch;
        t += '#';
    }
    const int size = static_cast<int>(t.size());
    vector<int> rad(size, 0);  // rad[i]: palindrome radius at i, center included
    ll total = 0;
    int mid = 0, mx = -1;      // rightmost palindrome found so far: [.., mx]
    for (int i = 0; i < size; i++) {
        rad[i] = (i <= mx) ? min(rad[2 * mid - i], mx - i + 1) : 1;
        while (i - rad[i] >= 0 && i + rad[i] < size &&
               t[i + rad[i]] == t[i - rad[i]])
            rad[i]++;
        if (i + rad[i] - 1 > mx) {
            mx = i + rad[i] - 1;
            mid = i;
        }
        // In the '#'-separated string, rad[i] / 2 is the number of
        // palindromes of the original string centered at i.
        total += rad[i] / 2;
    }
    return total;
}

// Reads the next whitespace-delimited token from stdin (same skipping rules as
// scanf("%s")) into a growable string, so long input cannot overflow a buffer.
// Returns "" at end of input.
string readToken() {
    string tok;
    int c = getchar();
    while (c != EOF && isspace(c)) c = getchar();
    while (c != EOF && !isspace(c)) {
        tok.push_back(static_cast<char>(c));
        c = getchar();
    }
    return tok;
}

int main() {
    const string str = readToken();
    const int n = static_cast<int>(str.size());

    // FFT length: smallest power of two >= 2n, so the self-convolution
    // (indices 0 .. 2n-2) does not wrap around. Buffers are sized from it.
    int lim = 1, bits = 0;
    while (lim < 2 * n) {
        lim <<= 1;
        bits++;
    }
    vector<int> rev(lim, 0);
    for (int i = 1; i < lim; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));

    // pairs[k]: ordered (i, j) with i + j == k and str[i] == str[j].
    vector<ll> pairs(lim, 0);
    string alphabet = str;
    sort(alphabet.begin(), alphabet.end());
    alphabet.erase(unique(alphabet.begin(), alphabet.end()), alphabet.end());
    for (char ch : alphabet) addSelfConvolution(str, ch, rev, pairs);

    ll total = 0;
    for (int k = 0; k <= 2 * n - 2; k++) {
        // Unordered pairs i < j were counted twice and i == j (even k only)
        // once: halving keeps the former, then the middle character is added.
        const ll m = pairs[k] / 2 + (k % 2 == 0 ? 1 : 0);
        total = (total + fastpow(2, m) - 1) % MOD;
    }

    const ll answer =
        ((total - countPalindromes(str) % MOD) % MOD + MOD) % MOD;
    printf("%lld ", answer);
    return 0;
}