【题意】给定只含小写字母的字符串s,定义价值为回文子串的长度*出现次数,求最大价值。n<=3*10^5。
【算法】回文树
【题解】回文树上一个点的被访问次数是其作为最长回文子串的出现次数。
将fail边反向连接建树后,每个点的子树访问次数和就是这个回文子串的出现次数,可以dfs解决。
注意:要从-1点开始dfs才能保证到达所有点。记得开long long。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int maxn=600010; int n,tot,len[maxn],fail[maxn],length,sz,nownode,ch[maxn][30],first[maxn],size[maxn]; ll ans=0; struct edge{int v,from;}e[maxn]; char s[maxn]; int getfail(int x){while(s[length-len[x]-1]!=s[length])x=fail[x];return x;} void build(){ int y=s[++length]-'a'; int x=getfail(nownode); if(!ch[x][y]){ len[++sz]=len[x]+2; fail[sz]=ch[getfail(fail[x])][y]; ch[x][y]=sz; } size[ch[x][y]]++; nownode=ch[x][y]; } void insert(int u,int v){tot++;e[tot].v=v;e[tot].from=first[u];first[u]=tot;} int dfs(int x){ ll sum=size[x]; for(int i=first[x];i;i=e[i].from){ sum+=dfs(e[i].v); } ans=max(ans,sum*len[x]); return sum; } int main(){ scanf("%s",s+1); n=strlen(s+1); len[0]=0;fail[0]=1; len[1]=-1;fail[1]=1; length=0; sz=1;//!!! for(int i=1;i<=n;i++)build(); insert(1,0); for(int i=2;i<=sz;i++){ insert(fail[i],i); } dfs(1);//1 printf("%lld",ans); return 0; }
补充做法:
【算法】manacher+后缀自动机
【题解】用manacher求出本质不同的回文串(右端位置和长度)。
对于每个回文串在后缀自动机找到对应节点:每个点在parent树上倍增,先找到回文串右端位置所在的前缀开端节点,然后倍增到Len恰好合适的位置。
对应节点的Right集合大小就是出现次数。
不会manacher,就当练习一下PAM+SAM了。(然后被BZOJ卡空间,在UOJ过的)
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int maxn=600010; struct tree{int len,fa,t[26];}t[maxn]; int a[maxn],b[maxn],d[maxn],n,m,tot,last,r[maxn],root,w[maxn],e[maxn],f[maxn][21]; char s[maxn]; namespace hw{ int len[maxn],ch[maxn>>1][26],fail[maxn],last,nowlen,sz; int getfail(int x){while(s[nowlen-len[x]-1]!=s[nowlen])x=fail[x];return x;}//while,if void ins(){ int c=s[++nowlen]-'a'; int x=getfail(last); if(!ch[x][c]){ len[++sz]=len[x]+2; fail[sz]=ch[getfail(fail[x])][c]; ch[x][c]=sz; a[sz]=nowlen;b[sz]=len[sz]; } last=ch[x][c]; } int solve(){ len[0]=0;fail[0]=1; len[1]=-1;fail[1]=1; sz=1;nowlen=last=0; for(int i=1;i<=n;i++)ins(); return sz; } } void insert(int c){ int np=++tot; t[np].len=t[last].len+1; d[t[np].len]=np;r[np]=1; int x=last; last=np; while(x&&!t[x].t[c])t[x].t[c]=np,x=t[x].fa; if(!x)t[np].fa=root;else{ int y=t[x].t[c]; if(t[y].len==t[x].len+1)t[np].fa=y;else{ int nq=++tot; t[nq]=t[y]; t[nq].len=t[x].len+1; t[nq].fa=t[y].fa;t[y].fa=t[np].fa=nq; while(x&&t[x].t[c]==y)t[x].t[c]=nq,x=t[x].fa; } } } int main(){ scanf("%s",s+1);n=strlen(s+1); m=hw::solve(); tot=last=root=1; for(int i=1;i<=n;i++)insert(s[i]-'a'); for(int i=1;i<=tot;i++)w[t[i].len]++; for(int i=1;i<=n;i++)w[i]+=w[i-1]; for(int i=1;i<=tot;i++)e[w[t[i].len]--]=i; for(int o=tot;o>=1;o--){ int i=e[o]; r[t[i].fa]+=r[i]; } for(int i=1;i<=tot;i++)f[i][0]=t[i].fa; for(int j=1;j<=20;j++){ for(int i=1;i<=tot;i++)f[i][j]=f[f[i][j-1]][j-1]; } long long ans=0; for(int i=1;i<=m;i++){ int x=d[a[i]]; for(int j=20;j>=0;j--)if(f[x][j]&&t[f[x][j]].len>=b[i])x=f[x][j];//0 ans=max(ans,1ll*b[i]*r[x]); } printf("%lld",ans); return 0; }