思路：插入和修改操作是 splay 的基本模型；对于询问，可以二分答案（公共前缀的长度），再借助 splay 判断两段子串是否相等。关键在于如何快速地做这个判断。
判断可以用字符串哈希：splay 的每个节点维护一个哈希域，表示其子树所对应子串的哈希值；选定一个进制（代码中为 37）做多项式哈希，二分时直接比较两段子串的哈希值即可。哈希值利用无符号整数的自然溢出取模即可。
// Longest common prefix of two suffixes of a string that changes under
// single-character insertions and replacements.
//
// The string lives in a splay tree ordered by position, with one sentinel
// node before the first character and one after the last. Each node caches
// the size and the polynomial hash (base kBase, arithmetic modulo 2^64 via
// unsigned wrap-around) of the substring spelled by its subtree. To read the
// hash of characters [l, l+len-1], the node just before l is splayed to the
// root and the node just after the range is lifted to the root's right
// child; the range is then exactly that child's left subtree. A query
// binary-searches the prefix length, comparing two range hashes per step.
//
// Input: the initial string, then m, then m commands (1-based positions):
//   Q x y  -> print the common-prefix length of the suffixes at x and y
//   R x d  -> replace the x-th character with d
//   I x d  -> insert d after the x-th character (x = 0 inserts at the front)
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>

namespace {

// Maximum string length at any time. The width in the initial scanf format
// in main() ("%100000s") must equal this value.
constexpr int kMaxLen = 100000;
// Node ids are 1..tot with tot <= kMaxLen + 2 (two sentinels); id 0 is null.
constexpr int kMaxNodes = kMaxLen + 5;

using Hash = unsigned long long;  // wraps modulo 2^64 by design
constexpr Hash kBase = 37;

Hash power_of_base[kMaxNodes + 1];  // power_of_base[i] = kBase^i (mod 2^64)
char text[kMaxNodes];               // initial string, stored 1-based
int initial_len;

// Position-ordered splay tree. It relies on static storage for zero
// initialisation: node 0 keeps size 0 and hash 0, and fresh nodes start with
// no children and no parent.
struct SplayTree {
  // side[x] == kRoot when x is the root of a (possibly detached) tree;
  // otherwise side[x] is 0 for a left child and 1 for a right child.
  static constexpr int kRoot = 2;

  int ch[kMaxNodes][2];
  int fa[kMaxNodes];
  int side[kMaxNodes];
  int size[kMaxNodes];
  Hash hash[kMaxNodes];
  Hash val[kMaxNodes];
  int root;
  int tot;  // nodes ever created; nodes are never freed, so == length + 2

  // Recomputes size and hash of x from its (already up-to-date) children.
  // The subtree hash reads the sequence most-significant-digit first:
  // hash(L) * B^(|R|+1) + val * B^|R| + hash(R).
  void update(int x) {
    if (x == 0) return;
    const int l = ch[x][0];
    const int r = ch[x][1];
    size[x] = size[l] + size[r] + 1;
    hash[x] = hash[l] * power_of_base[size[r] + 1] +
              val[x] * power_of_base[size[r]] + hash[r];
  }

  // Allocates a leaf holding v that will sit on side s of its parent.
  int new_node(Hash v, int s) {
    const int x = ++tot;
    val[x] = v;
    side[x] = s;
    size[x] = 1;
    hash[x] = v;
    return x;
  }

  // Builds a balanced subtree over text positions [l, r]; positions 0 and
  // initial_len + 1 become the sentinels. Returns the subtree root, or 0 if
  // the range is empty.
  int build(int l, int r, int s) {
    if (l > r) return 0;
    const int mid = (l + r) / 2;
    const bool sentinel = mid == 0 || mid == initial_len + 1;
    const int x = new_node(
        sentinel ? Hash{0} : static_cast<Hash>(text[mid] - 'a'), s);
    ch[x][0] = build(l, mid - 1, 0);
    ch[x][1] = build(mid + 1, r, 1);
    if (ch[x][0] != 0) fa[ch[x][0]] = x;
    if (ch[x][1] != 0) fa[ch[x][1]] = x;
    update(x);
    return x;
  }

  // Rotates x above its parent. Only the parent is refreshed here; x itself
  // is refreshed by splay() once it stops moving, because x is never read as
  // a child in between.
  void rotate(int x) {
    const int y = fa[x];
    const int z = fa[y];
    const int dir = side[x];
    const int y_side = side[y];
    const int inner = ch[x][dir ^ 1];

    ch[y][dir] = inner;
    if (inner != 0) {
      fa[inner] = y;
      side[inner] = dir;
    }
    ch[x][dir ^ 1] = y;
    fa[y] = x;
    side[y] = dir ^ 1;

    fa[x] = z;
    side[x] = y_side;
    if (y_side != kRoot) ch[z][y_side] = x;

    update(y);
  }

  // Splays x to the top of whatever tree it is in and records it as root.
  // When used on a detached subtree this temporarily clobbers `root`, so
  // callers must restore it (see lift_into_right_child).
  void splay(int x) {
    while (side[x] != kRoot) {
      if (side[x] == side[fa[x]]) rotate(fa[x]);  // zig-zig: parent first
      rotate(x);
    }
    root = x;
    update(x);
  }

  // Returns the node of 1-based rank pos (rank 1 is the front sentinel) and
  // splays it to the root. Requires 1 <= pos <= tot.
  int find(int pos) {
    int x = root;
    for (;;) {
      const int left = size[ch[x][0]];
      if (pos == left + 1) break;
      if (pos <= left) {
        x = ch[x][0];
      } else {
        pos -= left + 1;
        x = ch[x][1];
      }
    }
    splay(x);
    return x;
  }

  // Leftmost node of x's right subtree (x must have a right child).
  int successor(int x) const {
    x = ch[x][1];
    while (ch[x][0] != 0) x = ch[x][0];
    return x;
  }

  // Precondition: x is the root and y lies in x's right subtree.
  // Makes y the right child of x, so y's left subtree holds exactly the nodes
  // strictly between x and y. It does this by detaching the right subtree,
  // splaying y inside it and reattaching it. x's size and hash stay valid
  // because the sequence to its right is unchanged.
  void lift_into_right_child(int x, int y) {
    const int right = ch[x][1];
    fa[right] = 0;
    side[right] = kRoot;
    splay(y);
    fa[y] = x;
    side[y] = 1;
    ch[x][1] = y;
    root = x;
  }

  // Inserts a new node holding v immediately after node x.
  // Requires x != the back sentinel.
  void insert_after(int x, Hash v) {
    splay(x);
    const int y = successor(x);
    lift_into_right_child(x, y);  // y is now x's right child with no left child
    const int z = new_node(v, 0);
    fa[z] = y;
    ch[y][0] = z;
    splay(z);  // refreshes y and x on the way up
  }

  // Replaces the value stored at node x. Splaying x afterwards refreshes
  // every former ancestor, since each one becomes a descendant of x.
  void replace(int x, Hash v) {
    val[x] = v;
    splay(x);
  }

  // Hash of the len characters starting at 1-based character position l.
  // Requires len >= 1 and l + len - 1 <= length().
  Hash range_hash(int l, int len) {
    const int before = find(l);            // rank l holds character l - 1
    const int after = find(l + len + 1);   // rank l+len+1 holds character l+len
    splay(before);
    lift_into_right_child(before, after);
    const Hash h = hash[ch[after][0]];
    splay(after);
    return h;
  }

  // True when the len-character substrings at positions x and y hash equal.
  bool equal(int x, int y, int len) {
    return range_hash(x, len) == range_hash(y, len);
  }

  // Current string length (excluding the two sentinels).
  int length() const { return tot - 2; }
};

SplayTree tree;

// Length of the longest common prefix of the suffixes starting at 1-based
// positions x and y. Out-of-range positions yield 0 instead of looping
// forever in find().
int longest_common_prefix(int x, int y) {
  const int len = tree.length();
  if (x < 1 || y < 1 || x > len || y > len) return 0;
  int hi = std::min(len - x + 1, len - y + 1);
  if (x == y) return hi;  // a suffix trivially matches itself
  int lo = 1;
  int best = 0;
  while (lo <= hi) {
    const int mid = lo + (hi - lo) / 2;
    if (tree.equal(x, y, mid)) {
      best = mid;
      lo = mid + 1;
    } else {
      hi = mid - 1;
    }
  }
  return best;
}

}  // namespace

int main() {
  // Field width must equal kMaxLen so the read cannot overrun `text`.
  if (std::scanf("%100000s", text + 1) != 1) return 0;
  int m = 0;
  if (std::scanf("%d", &m) != 1) return 0;
  initial_len = static_cast<int>(std::strlen(text + 1));

  power_of_base[0] = 1;
  for (int i = 1; i <= kMaxNodes; ++i) {
    power_of_base[i] = power_of_base[i - 1] * kBase;
  }

  tree.root = tree.build(0, initial_len + 1, SplayTree::kRoot);

  char op[8];
  char letter[8];
  while (m-- > 0) {
    if (std::scanf("%7s", op) != 1) break;
    int x = 0;
    int y = 0;
    if (op[0] == 'I') {
      if (std::scanf("%d%7s", &x, letter) != 2) break;
      // Valid insert points are 0..length(); also refuse to exceed capacity.
      if (x < 0 || x > tree.length() || tree.tot + 1 >= kMaxNodes) continue;
      // Rank x + 1 is the x-th character (rank 1 is the front sentinel).
      tree.insert_after(tree.find(x + 1), static_cast<Hash>(letter[0] - 'a'));
    } else if (op[0] == 'R') {
      if (std::scanf("%d%7s", &x, letter) != 2) break;
      if (x < 1 || x > tree.length()) continue;
      tree.replace(tree.find(x + 1), static_cast<Hash>(letter[0] - 'a'));
    } else {
      if (std::scanf("%d%d", &x, &y) != 2) break;
      std::printf("%d\n", longest_common_prefix(x, y));  // one answer per line
    }
  }
  return 0;
}