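// SAM感觉写起来比SA更直观(?) — i.e. a suffix automaton (SAM) feels more intuitive to write than a suffix array (SA), perhaps.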
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <climits>
#include <algorithm>

typedef long long ll;

// Array capacity. A suffix automaton over a string of length n has at most
// 2n-1 states (plus the unused index 0), so n must stay below N / 2.
const int N = 600005;

// "Infinity" sentinel for the per-state top-two / bottom-two value trackers.
const int INF = 2000000000;

using namespace std;

// Reads a (possibly negative) decimal integer from stdin, skipping any
// non-digit characters before it; a '-' immediately preceding the digits
// makes it negative. Stops (returning 0) instead of spinning forever at EOF.
inline int read() {
    int ret = 0;
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    bool neg = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        neg = (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        ret = ret * 10 + (ch - '0');
        ch = getchar();
    }
    return neg ? -ret : ret;
}

// Returns the next character for which isalpha() is true, skipping all other
// input, or EOF if the input ends first. The value is kept as int so that
// isalpha() only ever sees EOF or an unsigned-char value.
inline int readch() {
    int ch = getchar();
    while (ch != EOF && !isalpha(ch)) ch = getchar();
    return ch;
}

// Suffix automaton whose states also carry the two largest and two smallest
// data values attached to their end positions.
struct SAM {
    // For each state: max0 >= max1 are the two largest data values and
    // min0 <= min1 the two smallest, over the state's end positions (filled
    // in for leaves by extend(), aggregated up the link tree by solve()).
    int max0[N], max1[N], min0[N], min1[N];

    struct SAMnode {
        int par;     // suffix link; 0 means "none" (only the root, state 1)
        int mx;      // length of the longest string in this state
        int go[26];  // transitions on 'a'..'z'; 0 means absent
        int rights;  // endpos size: 1 for states created as `np`, then
                     // accumulated from link-tree children in solve()
        SAMnode() {}
        explicit SAMnode(int len) : par(0), mx(len), rights(0) {
            memset(go, 0, sizeof(go));
        }
    } t[N];

    int last, size;         // last: state of the whole string so far; size: #states
    int bucket[N], order[N];  // counting-sort scratch; order[1..size] sorted by mx

    // Allocates a fresh state with longest length `len` and empty trackers.
    int newnode(int len) {
        t[++size] = SAMnode(len);
        max0[size] = max1[size] = -INF;
        min0[size] = min1[size] = INF;
        return size;
    }

    // Resets to an automaton holding only the root (state 1).
    void clear() {
        size = 0;
        last = newnode(0);
    }

    // Appends character c (assumed to be in 'a'..'z' -- NOTE(review): other
    // letters would index go[] out of range) with value `data` attached to the
    // new end position. Standard online SAM construction with cloning.
    void extend(char c, int data) {
        const int ch = c - 'a';
        int p = last;
        const int np = newnode(t[p].mx + 1);
        t[np].rights = 1;
        max0[np] = min0[np] = data;
        for (; p && !t[p].go[ch]; p = t[p].par) t[p].go[ch] = np;
        if (!p) {
            t[np].par = 1;
        } else {
            const int q = t[p].go[ch];
            if (t[p].mx + 1 == t[q].mx) {
                t[np].par = q;
            } else {
                const int nq = newnode(t[p].mx + 1);
                memcpy(t[nq].go, t[q].go, sizeof(t[q].go));
                t[nq].par = t[q].par;
                t[q].par = t[np].par = nq;
                for (; p && t[p].go[ch] == q; p = t[p].par) t[p].go[ch] = nq;
            }
        }
        last = np;
    }

    // Counting-sorts states by mx into order[1..size] (ascending mx), so that
    // iterating it backwards visits every child before its suffix-link parent.
    void precompute() {
        memset(bucket, 0, sizeof(bucket));
        for (int i = 1; i <= size; ++i) ++bucket[t[i].mx];
        for (int i = 1; i <= size; ++i) bucket[i] += bucket[i - 1];
        for (int i = size; i; --i) order[bucket[t[i].mx]--] = i;
        t[0].mx = -1;  // the "no parent" pseudo-state sits below every length
    }

    // Inserts `data` into state x's top-two maximum tracker.
    void Max(int x, int data) {
        max1[x] = max(max1[x], data);
        if (max1[x] > max0[x]) swap(max1[x], max0[x]);
    }

    // Inserts `data` into state x's bottom-two minimum tracker.
    void Min(int x, int data) {
        min1[x] = min(min1[x], data);
        if (min1[x] < min0[x]) swap(min1[x], min0[x]);
    }

    // Walks states by decreasing mx. For each state u with rights > 1:
    //  * maxv[mx(u)] gets the best product of two of u's values (either the two
    //    largest or the two smallest, to cover negative values);
    //  * C(rights, 2) is added to cnt[mx(u)] and subtracted at cnt[mx(par)],
    //    so a later suffix sum over cnt charges it to lengths (mx(par), mx(u)].
    // Each state's rights and value trackers are then merged into its parent.
    // Precondition: precompute() has been called.
    void solve(ll *cnt, ll *maxv) {
        for (int i = size; i; --i) {
            const int u = order[i];
            const int par = t[u].par;
            if (t[u].rights > 1) {
                const int len = t[u].mx;
                maxv[len] = max(maxv[len],
                                max((ll)max0[u] * max1[u], (ll)min0[u] * min1[u]));
                const ll pairs = (ll)t[u].rights * (t[u].rights - 1) / 2;
                cnt[len] += pairs;
                // The root has no parent: its range already starts at 0, so
                // there is nothing to cancel. (Previously this wrote cnt[-1],
                // an out-of-bounds store.)
                if (par) cnt[t[par].mx] -= pairs;
            }
            if (!par) continue;  // root: nothing above it to aggregate into
            t[par].rights += t[u].rights;
            Max(par, max0[u]);
            Max(par, max1[u]);
            Min(par, min0[u]);
            Min(par, min1[u]);
        }
    }
};

SAM sam;
char st[N];
int a[N], n;
ll cnt[N], ans[N];

// Input: n, then n letters, then n integers a[1..n].
// Output, for each r = 0..n-1 on its own line: the number of pairs of suffixes
// of st[1..n] whose longest common prefix is at least r, and the largest
// product a[p] * a[q] over those pairs (0 0 when there is no such pair).
//
// Method: the automaton is built over the REVERSED string, so a state's
// endpos size equals the number of suffixes of st that begin with the state's
// strings, and the state covers prefix lengths (mx(par), mx].
int main() {
    n = read();
    for (int i = 1; i <= n; ++i) {
        const int c = readch();
        if (c == EOF) return 1;  // truncated input
        st[i] = (char)c;
    }
    for (int i = 1; i <= n; ++i) a[i] = read();

    sam.clear();
    for (int i = n; i; --i) sam.extend(st[i], a[i]);
    sam.precompute();

    fill(ans, ans + N, LLONG_MIN);  // "minus infinity" for the running maximum
    memset(cnt, 0, sizeof(cnt));
    sam.solve(cnt, ans);

    // Turn the difference array into per-r pair counts.
    for (int i = n - 1; i >= 0; --i) cnt[i] += cnt[i + 1];
    // Pairs sharing a longer prefix also qualify for every shorter r.
    for (int i = n - 1; i >= 0; --i) ans[i] = max(ans[i], ans[i + 1]);
    for (int i = n - 1; i >= 0; --i)
        if (!cnt[i]) ans[i] = 0;

    for (int i = 0; i < n; ++i) printf("%lld %lld\n", cnt[i], ans[i]);
    return 0;
}