题目链接: https://www.luogu.com.cn/problem/P4770
SAM好题.
(I)首先我们考虑l = 1,r = |S|的情况怎么做
我们要求的是本质不同的子串str的数量,满足str是T的子串,且str不是$S_{l,r}$的子串
容易用补集转化成T本质不同的子串数减去S和T本质不同子串数
第一个问题很平凡,我们考虑第二个问题
我们对S,T分别建自动机,令T在S上面跑匹配,同时按着S的跑法在T自己上面跑匹配(因为T的每个子串都为$SAM_{T}$所接受,所以一定能跑)
对于每个前缀我们都可以求出它和S的最长公共后缀l,及在T上的节点,容易发现这个节点以上的长度<=l的都是本质不同的公共子串,因为可能算重所以先打标记然后Treedp统计(这也是为什么要在T上跑的原因,
因为在S上面跑,每次都要遍历S的parent tree时间复杂度不对)
(II)接下来才是难点,如果l,r任意怎么做
显然对于每个子串都建后缀自动机是不可能的,我们思考我们这个后缀自动机到底干了什么呢?
1.判断有没有tran(p,c)的转移边.
2.判断p这个节点的maxlen和minlen
我们可以发现,只要用线段树合并维护出endpos集合,就可以完成区间的上诉两个问题.
int u = get(sam[p].ch[c],l + len,r); if(u){ len++; p = sam[p].ch[c]; x = sam[x].ch[c]; } else{ while(len != -1 && !get(sam[p].ch[c],l + len,r)){ len--; if(len == sam[sam[p].fa].len) p = sam[p].fa; }
其中get(p,l,r)表示p这个节点的endpos集合在[l,r]范围内的最大值
设正在匹配的最长公共子串为s
我们发现我们原本要做的事情是判断s在p这个节点上能不能添上'c'这个字符,即判断 if(sam[p].ch[c] != 0),但是因为有区间限制我们应判断是否存在一个位置x可以接上s+'c',即在[l,r]区间内,是否存在一个endpos(x)满足x - len(s+'c') + 1>= l,即x >= l + len(s+'c') - 1也即x >= len(s) + l,于是只要判断[l+len,r]区间内是否存在endpos集合的元素即可
注意我们若失配此时不应该直接跳fa,而应该先让len自减,要记住这个后缀自动机只是一个框架,是$S_{1,n}$而不是$S_{l,r}$的SAM.
有人可能会问:怎么暴力while怎么能过?
因为数据水? 其实这个时间复杂度是正确的,我们考虑势能分析法,容易发现每次while,len最多减少1,外面for循环每次最多增加1,所以单次匹配时间复杂度是O(|T|logn)的
有很多细节,看代码吧
/*NOI2018[你的名字]*/ #include<bits/stdc++.h> using namespace std; #define ll long long int read(){ char c = getchar(); int x = 0; while(c < '0' || c > '9') c = getchar(); while(c >= '0' && c <= '9') x = x * 10 + c - 48,c = getchar(); return x; } const int N = 2e6 + 10; struct SegmentTree{ int lc,rc; int mx; }t[N<<4];/*线段树维护endpos集合*/ int Rt[N],num,n; void pushup(int p){ t[p].mx = max(t[t[p].lc].mx,t[t[p].rc].mx); } void Insert(int &p,int l,int r,int pos){ if(!p) p = ++num; if(l == r){ t[p].mx = max(t[p].mx,pos); return; } int mid = (l + r) >> 1; if(pos <= mid) Insert(t[p].lc,l,mid,pos); else Insert(t[p].rc,mid+1,r,pos); pushup(p); } int merge(int p,int q,int l,int r){ if(!p || !q) return p | q; int u = ++num; int mid = (l + r) >> 1; t[u].lc = merge(t[p].lc,t[q].lc,l,mid); t[u].rc = merge(t[p].rc,t[q].rc,mid + 1,r); pushup(u); return u; } int query(int p,int l,int r,int a,int b){ if(a <= l && b >= r) return t[p].mx; int mid = (l + r) >> 1; int ans = 0; if(a <= mid) ans = max(ans,query(t[p].lc,l,mid,a,b)); if(b > mid) ans = max(ans,query(t[p].rc,mid+1,r,a,b)); return ans; } struct SAM{ int ch[26],len,fa; }sam[N<<1]; int lst = 1,cnt = 1; void ins(int c,int rt){ int p = lst,np = ++cnt;lst = np; sam[np].len = sam[p].len + 1; for(; !sam[p].ch[c]; p = sam[p].fa) sam[p].ch[c] = np; if(!p) sam[np].fa = rt; else{ int q = sam[p].ch[c]; if(sam[q].len == sam[p].len + 1) sam[np].fa = q; else{ int nq = ++cnt; sam[nq] = sam[q]; sam[nq].len = sam[p].len + 1; sam[np].fa = sam[q].fa = nq; for(; sam[p].ch[c] == q; p = sam[p].fa) sam[p].ch[c] = nq; } } } int head[N<<1]; int f[N<<1],tot; struct Edge{ int nxt,point; }edge[N<<1]; void add_edge(int u,int v){ edge[++tot].nxt = head[u]; edge[tot].point = v; head[u] = tot; } char S[N],T[N]; void dfs(int u){ for(int i = head[u]; i ; i = edge[i].nxt){ int v = edge[i].point; dfs(v); f[u] = max(f[u],f[v]); } f[u] = min(f[u],sam[u].len); } void getpos(int u){ for(int i = head[u]; i ; i = edge[i].nxt){ int v = edge[i].point; getpos(v); Rt[u] = merge(Rt[u],Rt[v],1,n); } } bool valid(int u,int len){ return len >= sam[sam[u].fa].len + 1 && len <= sam[u].len; } int get(int u,int l,int r){ if(l > r || !u) return 0; return query(Rt[u],1,n,l,r); } int getlen(int u,int l,int r){ int x = get(u,l,r); return min(sam[u].len,x - l + 1); } ll work(char *s,int rt,int l,int r){ int m = strlen(s+1); int p = 1,len = 0,x = rt; for(int i = rt + 1; i <= cnt; ++i){ add_edge(sam[i].fa,i); } for(int i = 1; i <= m; ++i){ int c = s[i] - 'a'; int u = get(sam[p].ch[c],l + len,r); if(u){ len++; p = sam[p].ch[c]; x = sam[x].ch[c]; } else{ while(len != -1 && !get(sam[p].ch[c],l + len,r)){ len--; if(len == sam[sam[p].fa].len) p = sam[p].fa; } if(len == -1){ p = 1; len = 0; x = rt; } else{ len++; p = sam[p].ch[c]; while((!sam[x].ch[c] || !valid(sam[x].ch[c],len)) && x) x = sam[x].fa; if(!x) x = rt; x = sam[x].ch[c]; } } // cout<<i<<' '<<len<<endl; f[x] = max(f[x],len); } dfs(rt); ll ans = 0; for(int i = rt + 1; i <= cnt; ++i){/*!!!attention*/ if(f[i] > sam[sam[i].fa].len){ // assert(f[i] > sam[sam[i].fa].len); ans += f[i] - sam[sam[i].fa].len; } } for(int i = rt; i <= cnt; ++i) f[i] = 0; return ans; } int main(){ freopen("name.in","r",stdin); freopen("name.out","w",stdout); scanf("%s",S+1); n = strlen(S+1); for(int i = 1; i <= n; ++i){ ins(S[i]-'a',1); Insert(Rt[lst],1,n,i); } for(int i = 2; i <= cnt; ++i){ add_edge(sam[i].fa,i); } getpos(1); int q = read(); while(q--){ scanf("%s",T+1); int l = read(),r = read(); int m = strlen(T+1); int rt = ++cnt; lst = rt; for(int i = 1; i <= m; ++i){ ins(T[i]-'a',rt); } ll ans = 0; for(int i = rt + 1; i <= cnt; ++i){ ans += sam[i].len - sam[sam[i].fa].len; } ans -= work(T,rt,l,r); printf("%lld ",ans); } return 0; }