70分算法:我们把所有的串都添加到(AC)自动机,然后按照(y)排序,时间复杂度(O(n^2))。
我们考虑一些优化,
我们求的其实就是有多少个(x)串是(y)串的前缀的后缀。
其实就是(fail)指针的检索结构,
然后我们可以建出(fail)树,
然后转换一下询问。
其实就是这个东西:
把所有(y)串的节点的前缀打上标记,然后在所有(x)串节点的子树里统计个数,就是答案。
维护个数和其实可以用(fail)树上求个(dfs)序然后使用线段树/树状数组完成。
可是这样依然会超时。
我们考虑继续优化。
我们将y串使用原(trie)树上的(dfs)序。
因为按照(dfs)遍历每个点最多只会进出一次,所以时间复杂度是(O(n log n))
代码
#include <bits/stdc++.h>
const int maxn = 1e5 + 10;
int n, m, i, j, k, tim1, tim2, cnty, tot, tmp;
int ch[maxn][26], fail[maxn], fa[maxn], id[maxn];
std::vector<int> vec[maxn];
int dfn1[maxn], lca[maxn], rnk[maxn], c[maxn];
int dfn2[maxn], tl[maxn], tr[maxn];
int hd[maxn], ver[maxn], nxt[maxn], ans[maxn], cnte;
char s[maxn];
struct query {
int x, y, id;
query() { x = y = id = 0; }
query(int _x,int _y,int _id) {
x = _x; y = _y; id = _id;
}
inline friend bool operator < (query a,query b) {
return rnk[a.y] < rnk[b.y];
}
} q[maxn];
inline void add(int x,int y) {
int n = tot + 1;
while(x <= n)
c[x] += y, x += x & -x;
}
inline int ask(int x) {
int res = 0;
while(x)
res += c[x], x -= x & -x;
return res;
}
inline void adde(int u,int v) {
ver[++cnte] = v; nxt[cnte] = hd[u];
hd[u] = cnte; return;
}
inline void get_fail() {
std::queue<int> q; fail[0] = 0;
for(int i = 0;i <= 25;i++) {
if(ch[0][i]) {
fail[ ch[0][i] ] = 0;
q.push(ch[0][i]);
}
}
while(!q.empty()) {
int u = q.front(); q.pop();
for(int i = 0;i <= 25;i++) {
if(ch[u][i]) {
fail[ ch[u][i] ] = ch[ fail[u] ][i];
q.push(ch[u][i]);
} else {
ch[u][i] = ch[ fail[u] ][i];
}
}
}
}
inline void ac_init() {
get_fail();
memset(hd,-1,sizeof(hd));
for(int i = 1;i <= tot;i++)
adde(fail[i],i);
return;
}
void get_dfn1(int u) {
int len = vec[u].size();
for(int i = 0;i < len;i++) {
int x = vec[u][i];
dfn1[++tim1] = x;
rnk[x] = tim1;
id[x] = u;
lca[tim1] = tmp;
tmp = u;
}
for(int i = 0;i <= 25;i++) {
int v = ch[u][i];
if(!v)
continue;
get_dfn1(v);
tmp = u;
}
}
void get_dfn2(int u) {
dfn2[u] = ++tim2; tl[u] = tim2;
for(int i = hd[u];~i;i = nxt[i]) {
int v = ver[i];
get_dfn2(v);
}
tr[u] = tim2;
}
int main() {
scanf("%s",s + 1);
scanf("%d",&n);
for(int i = 1;i <= n;i++)
scanf("%d %d",&q[i].x,&q[i].y), q[i].id = i;
int u = 0;
for(int i = 1, l = strlen(s + 1);i <= l;i++) {
if(s[i] == 'P')
vec[u].push_back(++cnty);
else if(s[i] == 'B')
u = fa[u];
else {
int c = s[i] - 'a';
if(!ch[u][c])
ch[u][c] = ++tot;
fa[ ch[u][c] ] = u;
u = ch[u][c];
}
}
get_dfn1(0);
ac_init();
get_dfn2(0);
std::sort(q + 1,q + n + 1);
u = 0;
for(int i = 1, j = 1;i <= cnty;i++) {
int x = dfn1[i], v = id[x];
while(u != lca[i])
add(dfn2[u],-1), u = fa[u];
while(v != u)
add(dfn2[v],1), v = fa[v];
u = id[x];
while(q[j].y == x) {
int w = id[ q[j].x ];
ans[ q[j].id ] = ask(tr[w]) - ask(tl[w] - 1);
j++;
}
}
for(int i = 1;i <= n;i++)
printf("%d
",ans[i]);
return 0;
}