定义
后缀平衡树,就是动态的维护后缀数组,可以 (O(log n)) 在末尾插入字符,(O(log n)) 查询 (rank,SA)。但是由于是维护的后缀信息,所以插入只能在末尾插入字符(然后转化成在开头加一个字符),相当于添加一个后缀。
在线构造
方法一:
我们需要一种能比较两个后缀大小的方法,最简单就是二分+Hash,(O(log n)) 实现,再加上平衡树插入复杂度,总复杂度 (O(log^2 n))。
方法二:
考虑另一种比较方法,因为每次只添加一个字符,也就是说如果把第一个字符删掉,那剩下的字符串在之前已经插入过后缀平衡树中了,我们只需要先比较一下两个后缀的第一个字符,后面字符串的比较直接调用之前信息就好。
那怎么快速比较后缀平衡树中两个后缀大小呢?我们给每个点一个权值区间 ([l,r]),定义这个点的权值 (tag_i) 为 (mid=frac{l+r}2)。那它左子树对应的区间就是 ([l,mid-1]),右子树对应的区间就是 ([mid+1,r])。发现如果按照中序遍历的顺序遍历整颗平衡树,那每个点的权值是单调递增的。
可是这是平衡树诶。如果是那种基于旋转重构的平衡树那岂不是每次旋转都要重构一遍子树内的权值? (emmm) 确实是这样,所以要用到一种更高级的平衡树---重量平衡树。重量平衡树就是要保证平衡不能是均摊平衡,然后要么没有旋转,要么旋转影响的子树大小是期望(log)或者均摊(log)。(Treap) 和替罪羊树都满足这个条件。所以我们直接拿 (Treap) 维护就好了。
代码实现
嗯以上就是基本概念了,然后讲一下怎么维护好我们需要的 (rank,sa,height) 数组。
再次强调这里把每次往末尾插入一个元素转化成每次往开头插入一个元素
回顾最开始说的比较方法,如果两个字符串首字符不一样那么直接比较,否则比较去掉首字符后两个字符串的 (tag) 值。代码长这样:
bool cmp(int x,int y){ // 比较第x个插入的后缀和第y个插入的后缀哪个字典序小 x<y返回1
return s[x]<s[y] or s[x]==s[y] and tag[x-1]<tag[y-1];
}
如果需要往平衡树里插入字符c,设当前要插入的元素是第 (tot) 个元素,即 (s[tot]=c)。先把 (tot) 扔进去,找到 (tot) 的前驱后继,也就是它在 (sa) 数组上的前驱后继 (pre,nxt)。因为要维护好 (height),之前的 (height[nxt]=lcp(pre,nxt)) ,如果要往中间插入一个 (tot) 的话,那就需要让 (height[tot]=lcp(pre,tot),height[nxt]=lcp(tot,nxt))。这样就维护好了 (height) 数组。
(sa) 和 (rank) 比较简单,一个是找第 (k) 大,一个是找排名为 (k),都是平衡树的基本操作了。
还有就是如果要删除怎么办。同样找到 $pre,nxt $ ,令 (height[nxt]=lcp(pre,nxt)) 即可。然后在平衡树上删除点 (tot) 的时候也要注意一下,如果当前已经找到了 (tot),那就不用像普通的 (Treap) 一样旋转到叶子结点再删除,因为每次旋转的时候都需要遍历整个子树然后重构权值,所以这里直接像 (fhq\_Treap) 一样把 (tot) 的两个孩子 (merge) 起来,最后遍历一遍就好了。删除代码:
void remove(int &x,int l,int r){
if(x==tot){
x=merge(ch[x][0],ch[x][1]);
dfs(x,l,r);return;
} else{
sze[x]--;
int mid=l+r>>1;
if(cmp(x,tot)) remove(ch[x][1],mid+1,r);
else remove(ch[x][0],l,mid-1);
}
}
同时因为我们需要二分+哈希维护 (height),所以还需要动态维护哈希值。
别的就没啥了。
也不知道有啥用。
Code
一道例题 要求往字符串末尾插入一个字符,撤销一个插入,询问当前字符串本质不同子串个数。
#pragma GCC optimize(2)
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
#define int long long
const int N=1e5+5;
const int inf=1e18;
const int base=9973;
char s[N];
int lcp[N],ans,prio[N];
int root,tag[N],ch[N][2];
int n,tot,sze[N];;ull hsh[N],pw[N];
int getint(){
int X=0,w=0;char ch=getchar();
while(!isdigit(ch))w|=ch=='-',ch=getchar();
while( isdigit(ch))X=X*10+ch-48,ch=getchar();
if(w) return -X;return X;
}
bool cmp(int x,int y){
return s[x]<s[y] or s[x]==s[y] and tag[x-1]<tag[y-1];
}
void dfs(int x,int l,int r){
if(!x) return;
int mid=l+r>>1;
tag[x]=mid;
dfs(ch[x][0],l,mid-1),dfs(ch[x][1],mid+1,r);
sze[x]=sze[ch[x][0]]+sze[ch[x][1]]+1;
}
void rotate(int &x,int d,int l,int r){
int y=ch[x][d],z=ch[y][d^1];
ch[y][d^1]=x;ch[x][d]=z;
x=y;dfs(x,l,r);
}
void insert(int &x,int l,int r){
if(!x) {
x=tot;tag[x]=l+r>>1;
sze[x]=1;ch[x][0]=ch[x][1]=0;
prio[x]=rand();return;
} int d=cmp(x,tot),mid=l+r>>1;
sze[x]++;
if(d) insert(ch[x][d],mid+1,r);
else insert(ch[x][d],l,mid-1);
if(prio[ch[x][d]]<prio[x]) rotate(x,d,l,r);
}
int find(int x,int now){
if(x==now) return sze[ch[x][0]]+1;
int d=cmp(x,now);
if(d) return sze[ch[x][0]]+1+find(ch[x][1],now);
else return find(ch[x][0],now);
}
int kth(int x,int k){
if(!x) return 0;
if(sze[ch[x][0]]==k-1) return x;
if(sze[ch[x][0]]>=k) return kth(ch[x][0],k);
return kth(ch[x][1],k-sze[ch[x][0]]-1);
}
bool eq(int l1,int l2,int len){
ull a=hsh[l1+len-1]-hsh[l1-1]*pw[len],b=hsh[l2+len-1]-hsh[l2-1]*pw[len];
return a==b;
}
int getlcp(int a,int b){
int l=1,r=min(a,b),ans=0;
while(l<=r){
int mid=l+r>>1;
if(eq(a-mid+1,b-mid+1,mid)) ans=mid,l=mid+1;
else r=mid-1;
} return ans;
}
void ins(int x){
s[++tot]=s[x];hsh[tot]=hsh[tot-1]*base+s[x]-'a';
insert(root,1,inf);
int a=find(root,tot),b=kth(root,a-1),c=kth(root,a+1);
ans-=lcp[c];
lcp[tot]=getlcp(b,tot),lcp[c]=getlcp(tot,c);
ans+=lcp[tot]+lcp[c];
}
int merge(int x,int y){
if(!x or !y) return x+y;
if(prio[x]<prio[y]) {
ch[x][1]=merge(ch[x][1],y);
return x;
}
else {
ch[y][0]=merge(x,ch[y][0]);
return y;
}
}
void remove(int &x,int l,int r){
if(x==tot){
x=merge(ch[x][0],ch[x][1]);
dfs(x,l,r);return;
} else{
sze[x]--;
int mid=l+r>>1,d=cmp(x,tot);
if(d) remove(ch[x][d],mid+1,r);
else remove(ch[x][d],l,mid-1);
}
}
void del(){
int rk=find(root,tot);
int b=kth(root,rk-1),c=kth(root,rk+1);
ans-=lcp[tot]+lcp[c];
lcp[c]=getlcp(b,c);
ans+=lcp[c];
remove(root,1,inf);
tot--;
}
signed main(){
srand(20020619);
scanf("%s",s+1);n=strlen(s+1);
pw[0]=1;for(int i=1;i<=n;i++) pw[i]=pw[i-1]*base;
for(int i=1;i<=n;i++){
if(s[i]=='-') del();
else ins(i);
printf("%lld
",tot*(tot+1)/2-ans);
} return 0;
}