[BZOJ3756]Pty的字符串(广义SAM)
题面
在神秘的东方有一棵奇葩的树,它有一个固定的根节点(编号为1)。树的每条边上都是一个字符,字符为a,b,c中的一个.
你可以从树上的任意一个点出发,然后沿着远离根的边往下行走,在任意一个节点停止,将你经过的边的字符依次写下来,就能得到一个字符串
现在pty得到了一棵树和一个字符串S。如果S的一个子串[l,r]和树上某条路径所得到的字符串完全相同,则我们称这个子串和该路径匹配。现在pty想知道,S的所有子串和树上的所有路径的匹配总数是多少?
分析
因为要求的路径一定是深度单调的一条链,也就是Trie树上串的后缀。因此,我们可以对这棵树建出广义SAM。
对于每个广义SAM上的节点,我们计算该节点对答案的贡献,记为(sum_x)。可以递推求出:
[sum_x=sum_{link(x)}+(len(x)-len(link(x))) imes |right_x|
]
该式子的意义是:加上最长后缀的答案(sum_{link(x)}),又因为该节点会产生(len(x)-len(link(x)))个本质不同的子串,这些子串的出现位置有(|right_x|)个。广义SAM的right集合代表在Trie上不同位置的出现,求(right)集合大小和普通SAM类似,不再赘述。
然后把(S)放到广义SAM上匹配,如果当前匹配节点为(x),匹配长度为(matlen),那么和上面的递推式类似,答案就要加上(sum_{link(x)}+(matlen-len(link(x))) imes |right_x|).
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define maxn 8000000
#define maxc 26
using namespace std;
typedef long long ll;
template<typename T> inline void qread(T &x){
x=0;
T sign=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-') sign=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
x=x*10+c-'0';
c=getchar();
}
x=x*sign;
}
int n;
vector<int>E[maxn+5];
struct EXSAM{
#define len(x) t[x].len
#define link(x) t[x].link
struct node{
int ch[maxc];
int link;
ll len;
ll sz;//right集合大小
ll sum;//该节点产生的可能匹配的串
}t[maxn+5];
vector<int>par[maxn+5];//parent树
const int root=1;
int ptr=1;
int extend(int last,int c){
if(t[last].ch[c]&&len(last)+1==len(t[last].ch[c])){
t[t[last].ch[c]].sz++;
return t[last].ch[c];
}
int p=last,cur=++ptr,clo,flag=0;
t[cur].sz=1;
len(cur)=len(last)+1;
while(p&&t[p].ch[c]==0){
t[p].ch[c]=cur;
p=link(p);
}
if(p==0) link(cur)=root;
else{
int q=t[p].ch[c];
if(len(p)+1==len(q)) link(cur)=q;
else{
if(len(p)+1==len(cur)) flag=1;
clo=++ptr;
link(clo)=link(q);
len(clo)=len(p)+1;
for(int i=0;i<maxc;i++) t[clo].ch[i]=t[q].ch[i];
link(q)=link(cur)=clo;
while(p&&t[p].ch[c]==q){
t[p].ch[c]=clo;
p=link(p);
}
}
}
if(flag) return clo;
return cur;
}
void dfs1(int x){
for(int i=0;i<(int)par[x].size();i++){
int y=par[x][i];
dfs1(y);
t[x].sz+=t[y].sz;
}
}
void dfs2(int x){
for(int i=0;i<(int)par[x].size();i++){
int y=par[x][i];
t[y].sum=t[x].sum+t[y].sz*(len(y)-len(x));
dfs2(y);
}
}
void build(){
for(int i=2;i<=ptr;i++) par[link(i)].push_back(i);
dfs1(1);
dfs2(1);
}
ll query(char *tp){
int leng=strlen(tp+1);
int matlen=0;
int x=root;
ll ans=0;
for(int i=1;i<=leng;i++){
int c=tp[i]-'a';
while(x&&t[x].ch[c]==0){
x=link(x);
matlen=len(x);
}
if(t[x].ch[c]){
matlen++;//不能写matlen=len(x)+1,否则这样匹配的时候就会炸
x=t[x].ch[c];
ans+=t[link(x)].sum+(matlen-len(link(x)))*t[x].sz;
}else{
matlen=0;
x=root;
}
}
return ans;
}
}T;
char str[maxn+5];
int col[maxn+5];
void dfs(int x,int fa,int nd){
for(int i=0;i<(int)E[x].size();i++){
int y=E[x][i];
if(y!=fa){
dfs(y,x,T.extend(nd,col[y]));
}
}
}
int main(){
char ch[2];
int f;
qread(n);
for(int i=2;i<=n;i++){
qread(f);
scanf("%s",ch);
E[f].push_back(i);
col[i]=ch[0]-'a';
}
dfs(1,0,T.root);
T.build();
scanf("%s",str+1);
printf("%lld
",T.query(str));
}