#include <cstdio>
#include <cstring>

#define LL long long

/* Longest input string accepted. Must stay a plain integer literal because it
 * is stringified into the scanf field width in main(). */
#define MAX_LEN 500000
#define STRINGIFY_(x) #x
#define STRINGIFY(x) STRINGIFY_(x)

enum {
    TEXT_CAP = MAX_LEN + 10,     /* input + leading ' ' + trailing '$' + NUL */
    NODE_CAP = 2 * MAX_LEN + 10  /* capacity of every per-node / per-edge array */
};

/* st  : the text the tree is built over: ' ' + input + '$' (so it is 1-based).
 * stt : raw input buffer; '$' is appended to it in main(). */
char st[TEXT_CAP], stt[TEXT_CAP];

/* Result printed by main(); initialised there and decreased in getans(). */
LL ans;

/* leafr   : right end shared by every edge that leads to a leaf (see getr).
 * ndcnt   : number of nodes allocated so far; node 1 is the root.
 * sidcnt  : number of edges allocated so far.
 * lastcre : node most recently recorded by buildtree() for suffix-link fixup. */
int leafr, ndcnt = 1, sidcnt, lastcre;

/* Per-node arrays (indexed by node id):
 *   isleaf[v]   : 1 if v was created by ins(), 0 otherwise.
 *   dep[v]      : dep[parent] + split length for nodes made by cut();
 *                 dep[parent] + 1 for nodes made by ins().
 *   nd[v]       : first outgoing edge of v, or -1 if none.
 *   sufchain[v] : node that buildtree() jumps to after finishing a suffix at v.
 * Per-edge arrays (indexed by edge id):
 *   l[e], r[e]  : inclusive index range of the edge label inside st
 *                 (r[e] is ignored for leaf edges, see getr).
 *   des[e]      : node the edge leads to.
 *   next[e]     : next sibling edge in the parent's list, or -1. */
int isleaf[NODE_CAP], dep[NODE_CAP], nd[NODE_CAP], l[NODE_CAP], r[NODE_CAP],
    des[NODE_CAP], next[NODE_CAP], sufchain[NODE_CAP];

/* Right end of edge `sid`: the shared leafr for edges that end in a leaf,
 * the stored r[sid] otherwise. */
int getr(int sid)
{
    return isleaf[des[sid]] ? leafr : r[sid];
}

/* Split edge `sid` (which leaves node `po`) after its first `len` characters.
 * A new non-leaf node is placed at the split point; the tail of the old edge
 * becomes a new edge from that node to the old destination. The new node is
 * left as ndcnt for the caller. */
void cut(int po, int sid, int len)
{
    const int mid = ++ndcnt;   /* node inserted at the split point */
    const int tail = ++sidcnt; /* edge carrying the remainder of the label */

    dep[mid] = dep[po] + len;
    isleaf[mid] = 0;
    nd[mid] = tail;

    l[tail] = l[sid] + len;
    r[tail] = getr(sid); /* must be read before des[sid] is redirected below */
    next[tail] = -1;
    des[tail] = des[sid];

    des[sid] = mid;
    r[sid] = l[sid] + len - 1;
}

/* Hang a new leaf under node `po`; its edge label starts at text index `alph`
 * (the right end of a leaf edge is always taken from leafr by getr). */
void ins(int po, int alph)
{
    const int leaf = ++ndcnt;
    const int edge = ++sidcnt;

    dep[leaf] = dep[po] + 1;
    isleaf[leaf] = 1;
    nd[leaf] = -1;

    l[edge] = alph;
    r[edge] = alph;
    des[edge] = leaf;

    /* Push the edge onto the front of po's child list. */
    next[edge] = nd[po];
    nd[po] = edge;
}

/* Build the suffix tree of st[1..strlen(st)-1] online, one character per
 * outer iteration (Ukkonen-style: shared leaf end, suffix links in sufchain).
 *
 *   i       : index of the character being appended.
 *   lastins : start index of the next suffix that still has to be extended.
 *   po      : node from which the descent for the current suffix starts.
 *   prom    : non-zero when the node in lastcre still needs its sufchain set.
 */
void buildtree()
{
    int po = 1, lastins = 1, prom = 0;
    /* Hoisted: st is not modified while the tree is being built. */
    const int last = (int)strlen(st) - 1;

    nd[1] = -1;
    sufchain[1] = 1;

    for (int i = 1; i <= last; i++) {
        leafr = i;
        prom = 0;

        while (lastins <= i) {
            int matched = 0; /* set when st[lastins..i] is already in the tree */

            for (;;) {
                int p;

                /* Find the child edge of po whose first character continues
                 * the suffix that starts at lastins. */
                for (p = nd[po]; p != -1; p = next[p]) {
                    if (st[l[p]] == st[lastins + dep[po]])
                        break;
                }

                if (p == -1) {
                    /* No such edge: the new character starts a fresh leaf. */
                    ins(po, i);
                    if (prom)
                        sufchain[lastcre] = po;
                    prom = 0;
                    lastcre = po;
                    break;
                }

                /* Characters of st[lastins..i-1] that lie on edge p. */
                const int rest = i - lastins - dep[po];
                const int span = getr(p) - l[p] + 1;

                if (rest + 1 > span) {
                    /* The suffix runs past this edge: keep descending. */
                    po = des[p];
                    continue;
                }

                if (st[i] == st[l[p] + rest]) {
                    /* Already present: this phase is finished. */
                    matched = 1;
                    if (prom)
                        sufchain[lastcre] = po;
                    prom = 0;
                    lastcre = po;
                    break;
                }

                /* Mismatch inside the edge: split it and add a leaf. */
                cut(po, p, rest);
                if (prom)
                    sufchain[lastcre] = ndcnt;
                prom = (rest > 1) ? 1 : 0;
                lastcre = ndcnt;
                sufchain[lastcre] = 1; /* default to root until fixed up */
                ins(ndcnt, i);
                break;
            }

            if (matched)
                break;
            po = sufchain[po];
            lastins++;
        }
    }
}

/* Work list used by getans(): filled front-to-back as a queue, then consumed
 * back-to-front as a stack. Entry 0 is a dummy parent for the root. */
int top = 0;
struct str {
    int po, fr;  /* tree node, and index in sta of its parent entry */
    LL cnt, len; /* leaves counted below the node; length of the edge into it */
} sta[NODE_CAP]; /* one entry per tree node, so it needs node capacity */

/* Append node `p` with incoming edge length `le` and parent entry `ps`. */
void pb(int ps, int le, int p)
{
    ++top;
    sta[top].po = p;
    sta[top].len = le;
    sta[top].fr = ps;
    sta[top].cnt = 0;
}

/* Remove the last entry, adding its leaf count (plus one if it is itself a
 * leaf) to its parent entry. */
void pop()
{
    if (isleaf[sta[top].po])
        sta[top].cnt++;
    sta[sta[top].fr].cnt += sta[top].cnt;
    top--;
}

/* Subtract len * cnt * (cnt - 1) from ans for every tree node, where len is
 * the length of the edge entering the node and cnt the number of leaves in
 * its subtree. Iterative, so deep trees cannot overflow the call stack.
 * NOTE(review): together with the initial value set in main() this appears to
 * yield the sum over all suffix pairs of len(a) + len(b) - 2 * lcp(a, b);
 * confirm against the problem statement. */
void getans()
{
    top = 0;
    pb(0, 0, 1);

    /* Breadth-first fill: children always get larger indices than parents. */
    for (int head = 1; head <= top; head++) {
        for (int e = nd[sta[head].po]; e != -1; e = next[e])
            pb(head, getr(e) - l[e] + 1, des[e]);
    }

    /* Reverse order: every child is folded into its parent before the parent
     * itself is processed. Leaves still have cnt == 0 here and contribute 0. */
    while (top) {
        ans -= sta[top].len * sta[top].cnt * (sta[top].cnt - 1);
        pop();
    }
}

/* Read one whitespace-delimited word, build the suffix tree of word + '$',
 * and print the resulting value of ans.
 * Assumes the input does not itself contain '$' — TODO confirm. */
int main()
{
    /* Field width keeps the read inside stt; a failed read is treated as the
     * empty string. */
    if (scanf("%" STRINGIFY(MAX_LEN) "s", stt) != 1)
        stt[0] = '\0';

    LL len = (LL)strlen(stt);
    ans = (len - 1) * len * (len + 1) / 2;

    /* st = ' ' + input + '$'. Capacity: MAX_LEN + 2 characters + NUL fits in
     * TEXT_CAP, so the strcat calls cannot overflow. */
    st[0] = ' ';
    st[1] = '\0';
    strcat(stt, "$");
    strcat(st, stt);

    buildtree();
    getans();
    printf("%lld", ans);
    return 0;
}
注意:某一个串对应的结点到后缀树根的深度是 $O(n \cdot \sqrt{n})$ 的。