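与暴力程序对拍没有发现错误，但提交后莫名 WA 了。可能的原因：a[] 只写到下标 n，多次求后缀数组之间 n 之后会残留上一个（更长的）串的字符，而求 height 时的 while 循环没有边界检查，LCP 可能越过串尾；另外提交前需要去掉把输入输出重定向到 noip.in / noip.out 的 freopen。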
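利用容斥：设 calc(t) 为串 t 中所有后缀对的公共前缀里长度 ≥ k 的子串对数，则答案 = calc(s1+'%'+s2) − calc(s1) − calc(s2)，相减后只剩下一个子串来自 s1、另一个来自 s2 的公共子串对（'%' 不在原串中出现，保证 LCP 不会跨过分隔符）。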
计算 calc(t) 其实就是对每个 height[i]，求出它作为区间最小值时能向左、向右扩展到的最远范围：lr[i] 是左侧第一个比它小的位置，rr[i] 是右侧第一个比它小的位置。
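于是 height[i]（记为 w[i]）对答案的贡献为 (rr-i)*(i-lr)*(w[i]-kk+1)，只在 w[i] ≥ kk 时计入：区间左端有 (i-lr) 种选法、右端有 (rr-i) 种选法，每对后缀的公共前缀长度为 w[i]，其中长度 ≥ kk 的公共子串有 w[i]-kk+1 个。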
height 中可能有相等的值，所以先把它离散化成互不相同的名次（相等的值按排序后的先后区分），再求 lr、rr。这样每个区间只会被其唯一的最小值统计一次，不会重复计数。
另外，每次构造后缀数组前都要把 x、y 数组清空：倍增比较时会读到下标 n 之后的位置，那里必须是 0，不能残留上一次的数据。
// POJ 3415 "Common Substrings": for each test case, count the pairs
// (substring of s1, substring of s2) that are equal and have length >= kk.
//
// calc(t) = sum over all pairs of suffixes of t of max(0, LCP - kk + 1),
// i.e. the number of equal substring pairs of length >= kk inside t.
// By inclusion-exclusion the answer is
//     calc(s1 SEP s2) - calc(s1) - calc(s2),
// where SEP is a code that occurs in neither string, so no LCP crosses it.
#include <cstring>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <string>

typedef long long ll;

// Capacity: |s1| + |s2| + separator + terminator at index n + 1 must fit.
const ll N = 200000 + 1000;
// Characters are encoded as (unsigned char) + 1, i.e. 1..256; 0 is the
// "past the end" padding and SEP (257) joins s1 and s2.
const ll SEP = 257;

ll height[N], rk[N], sa[N], c[N], xbuf[N], ybuf[N], a[N], n, kk;
ll lr[N], rr[N];

// Builds the suffix array of a[1..n] (codes in 1..m) by prefix doubling,
// then fills rk (inverse of sa) and height[i] = LCP(suffix sa[i-1], suffix sa[i])
// for i >= 2, with height[1] = 0.
void asa(ll n, ll m)
{
    // Rank buffers must read 0 past index n: rank 0 means "past the end".
    std::memset(xbuf, 0, sizeof(xbuf));
    std::memset(ybuf, 0, sizeof(ybuf));
    ll *x = xbuf;  // x[i]: rank of suffix i by its first len characters
    ll *y = ybuf;  // scratch: suffixes by second key, then the previous ranks

    // Counting sort by the first character.
    for (ll i = 0; i <= m; ++i) c[i] = 0;
    for (ll i = 1; i <= n; ++i) ++c[x[i] = a[i]];
    for (ll i = 1; i <= m; ++i) c[i] += c[i - 1];
    for (ll i = n; i >= 1; --i) sa[c[x[i]]--] = i;

    for (ll len = 1; len <= n; len <<= 1) {
        // Order suffixes by the second key x[i + len]; those without a
        // second half (second key 0) come first.
        ll p = 0;
        for (ll i = n - len + 1; i <= n; ++i) y[++p] = i;
        for (ll i = 1; i <= n; ++i)
            if (sa[i] > len) y[++p] = sa[i] - len;

        // Stable counting sort by the first key.
        for (ll i = 0; i <= m; ++i) c[i] = 0;
        for (ll i = 1; i <= n; ++i) ++c[x[y[i]]];
        for (ll i = 1; i <= m; ++i) c[i] += c[i - 1];
        for (ll i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];

        // Re-rank by (first len, next len) characters. Swap the pointers,
        // not the arrays, to avoid copying N elements per round.
        std::swap(x, y);
        p = 1;
        x[sa[1]] = 1;
        for (ll i = 2; i <= n; ++i) {
            // Equal first keys imply both suffixes have >= len characters,
            // so sa[i] + len <= n + 1 and y there is a rank or the 0 padding.
            const bool same = y[sa[i]] == y[sa[i - 1]] &&
                              y[sa[i] + len] == y[sa[i - 1] + len];
            x[sa[i]] = same ? p : ++p;
        }
        if (p >= n) break;  // all ranks distinct: sa is final
        m = p;
    }

    for (ll i = 1; i <= n; ++i) rk[sa[i]] = i;

    // Kasai: the LCP with the rank-predecessor drops by at most 1 per step.
    ll f = 0;
    for (ll i = 1; i <= n; ++i) {
        if (rk[i] == 1) {  // no predecessor; restart the carried LCP
            height[1] = 0;
            f = 0;
            continue;
        }
        if (f > 0) --f;
        const ll j = sa[rk[i] - 1];
        // Explicit bounds: never compare codes left past n by an earlier,
        // longer string (the original unbounded loop read stale data here).
        while (i + f <= n && j + f <= n && a[i + f] == a[j + f]) ++f;
        height[rk[i]] = f;
    }
}

// Returns calc(a[1..n]) using height[2..n] from the last asa() call.
// Each run height[l..r] (2 <= l <= r <= n) is the LCP of suffixes
// sa[l-1] and sa[r]. Every run is charged to the rightmost position of its
// minimum: lr[i] = previous index with a strictly smaller height (1 is a
// wall) and rr[i] = next index with a height <= height[i] (n + 1 is a wall).
// This replaces the stack-allocated discretization of the original.
ll get_ans()
{
    for (ll i = 2; i <= n; ++i) {
        ll j = i - 1;
        // Heights in (lr[j], j) are >= height[j], so jump along lr.
        while (j >= 2 && height[j] >= height[i]) j = lr[j];
        lr[i] = j;
    }
    for (ll i = n; i >= 2; --i) {
        ll j = i + 1;
        while (j <= n && height[j] > height[i]) j = rr[j];
        rr[i] = j;
    }
    ll ans = 0;
    for (ll i = 2; i <= n; ++i)
        if (height[i] >= kk)
            // (i - lr) left ends * (rr - i) right ends, and each such pair of
            // suffixes shares height[i] - kk + 1 common substrings of length >= kk.
            ans += (height[i] - kk + 1) * (i - lr[i]) * (rr[i] - i);
    return ans;
}

// Appends the codes of `str` to a[n + 1 ..], advancing n.
void append_codes(const std::string& str)
{
    for (std::string::size_type i = 0; i < str.size(); ++i)
        a[++n] = static_cast<ll>(static_cast<unsigned char>(str[i])) + 1;
}

// Builds the suffix array of the current a[1..n] and returns calc(a[1..n]).
ll calc()
{
    a[n + 1] = 0;  // terminator (Kasai is bounds-checked as well)
    asa(n, SEP);
    return get_ans();
}

int main()
{
    // Local testing reads noip.in and writes noip.out. When that file is
    // absent (e.g. on a judge), keep stdin/stdout. An unconditional freopen
    // would close stdin and lose all input.
    if (std::FILE* probe = std::fopen("noip.in", "r")) {
        std::fclose(probe);
        std::freopen("noip.in", "r", stdin);
        std::freopen("noip.out", "w", stdout);
    }

    std::string s1, s2;
    while (std::cin >> kk && kk) {
        if (!(std::cin >> s1 >> s2)) break;
        // Largest index touched is |s1| + |s2| + 2 (separator + terminator).
        if (static_cast<ll>(s1.size() + s2.size()) + 3 > N) {
            std::cerr << "input too long\n";
            return 1;
        }

        n = 0;
        append_codes(s1);
        const ll x1 = calc();

        n = 0;
        append_codes(s2);
        const ll x2 = calc();

        n = 0;
        append_codes(s1);
        a[++n] = SEP;
        append_codes(s2);
        const ll x3 = calc();

        std::cout << x3 - x2 - x1 << '\n';
    }
    return 0;
}