「NOI2016」优秀的拆分
这不是个SAM题,只是个LCP题目
95分的Hash很简单,枚举每个点为开头和末尾的AA串个数,然后乘一下之类的。
考虑怎么快速求“每个点为开头和末尾的AA串个数”
考虑枚举A的长度,然后在序列中每|A|个位置放一个关键点,这样每个AA至少都经过了一个关键点。
然后求相邻两个关键点的lcs,lcp,画画图匹配一下,可以把区间内的都求出来了。
可以Hash二分或者sa或者sam
Code:
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using std::min;
const int N=120010;
struct SAM
{
int head[N],to[N],Next[N],cnt;
void add(int u,int v){to[++cnt]=v,Next[cnt]=head[u],head[u]=cnt;}
int dfn[N],st[18][N],Log[N],dep[N],pos[N],clock;
void dfs(int now)
{
st[0][dfn[now]=++clock]=now;
for(int i=head[now];i;i=Next[i])
dep[to[i]]=dep[now]+1,dfs(to[i]),st[0][++clock]=now;
}
int LCA(int x,int y)
{
x=dfn[x],y=dfn[y];
if(x>y) std::swap(x,y);
int d=Log[y+1-x];
x=st[d][x],y=st[d][y-(1<<d)+1];
return dep[x]<dep[y]?x:y;
}
int len[N],par[N],ch[N][26],las=1,tot=1;
void extend(int c)
{
int now=++tot,p=las;
len[now]=len[p]+1;
while(p&&!ch[p][c]) ch[p][c]=now,p=par[p];
if(!p) par[now]=1;
else
{
int x=ch[p][c];
if(len[x]==len[p]+1) par[now]=x;
else
{
int y=++tot;
len[y]=len[p]+1;
par[y]=par[x];
memcpy(ch[y],ch[x],sizeof ch[y]);
while(p&&ch[p][c]==x) ch[p][c]=y,p=par[p];
par[x]=par[now]=y;
}
}
las=now;
}
void init(char *s,int typ)
{
int n=strlen(s+1);
for(int i=1;i<=n;i++) extend(s[i]-'a'),pos[typ?n+1-i:i]=las;
for(int i=1;i<=tot;i++) add(par[i],i);
clock=0,dep[1]=1,dfs(1);
Log[0]=-1;for(int i=1;i<=clock;i++) Log[i]=Log[i>>1]+1;
for(int j=1;j<=17;j++)
for(int i=1;i<=clock-(1<<j)+1;i++)
{
int x=st[j-1][i],y=st[j-1][i+(1<<j-1)];
st[j][i]=dep[x]<dep[y]?x:y;
}
}
void clear()
{
memset(ch,0,sizeof ch);
memset(par,0,sizeof par);
memset(len,0,sizeof len);
memset(head,0,sizeof head);
cnt=0,las=tot=1;
}
int query(int x,int y)
{
return len[LCA(pos[x],pos[y])];
}
}LCS,LCP;
int d[N],f[N],g[N];
void work(char *s,int *f)
{
int n=strlen(s+1);
LCS.init(s,0);
std::reverse(s+1,s+n+1);
LCP.init(s,1);
for(int l=1;l<=n;l++)
{
for(int i=1;i+l<=n;i+=l)
{
int a=min(LCS.query(i,i+l),l),b=min(LCP.query(i,i+l),l);
if(a+b-1<l) continue;
int ss=i-a+1,tt=ss+a+b-l-1;
++d[ss+l*2-1],--d[tt+l*2];
}
}
for(int i=1;i<=n;i++) f[i]=f[i-1]+d[i];
memset(d,0,sizeof d);
LCP.clear(),LCS.clear();
}
char s[N];
int main()
{
int T;scanf("%d",&T);
while(T--)
{
memset(f,0,sizeof f);
memset(g,0,sizeof g);
scanf("%s",s+1);
int n=strlen(s+1);
work(s,f);
work(s,g);
std::reverse(g+1,g+1+n);
ll ans=0;
for(int i=1;i<=n;i++) ans=ans+f[i-1]*g[i];
printf("%lld
",ans);
}
return 0;
}
2019.3.15