求出后缀数组后, 对于每个询问, 借助 ST 表查询区间 LCP, 在后缀数组上二分出匹配区间的左右端点; 再把询问离线, 用树状数组统计即可。
或者不建 ST 表, 按 LCP 从大到小用并查集合并后缀数组中相邻的后缀, 求出左右端点, 这样空间和常数都更优。
// Offline substring-occurrence counting over a set of strings.
//
// Input (stdin): n and q, then n strings, then q queries "l r k".
// For each query the program prints how many suffixes of strings l..r
// start with the whole of string k, i.e. the number of occurrences of
// string k inside strings l..r.
//
// Approach, as implemented below:
//   1. Concatenate all strings, separated by distinct separator symbols
//      (values starting at 'z' + 1) and terminated by a 0 sentinel.
//   2. Build the suffix array and the LCP array of the concatenation.
//   3. Process strings by decreasing length while merging adjacent
//      suffix-array positions (disjoint-set union) in decreasing LCP order;
//      the component containing a string's own start is the suffix-array
//      interval [L, R] of suffixes that have that string as a prefix.
//   4. Sweep the suffix array once, inserting the owner of every suffix into
//      a Fenwick tree, and answer each query as prefix(R) - prefix(L - 1).
#include<bits/stdc++.h>
#define LL long long
#define LD long double
#define ull unsigned long long

#define fi first
#define se second
#define mk make_pair
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
#define SZ(x) ((int)x.size())
#define ALL(x) (x).begin(), (x).end()
#define fio ios::sync_with_stdio(false); cin.tie(0);

using namespace std;

const int N = 5e5 + 7;
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-8;
const double PI = acos(-1);

// a += b, reduced once into [0, mod) when the sum reaches mod.
template<class T, class S> inline void add(T &a, S b) {a += b; if(a >= mod) a -= mod;}
// a -= b, lifted once by mod when the difference goes negative.
template<class T, class S> inline void sub(T &a, S b) {a -= b; if(a < 0) a += mod;}
// Raise a to b when b is larger; returns whether a changed.
template<class T, class S> inline bool chkmax(T &a, S b) {return a < b ? a = b, true : false;}
// Lower a to b when b is smaller; returns whether a changed.
template<class T, class S> inline bool chkmin(T &a, S b) {return a > b ? a = b, true : false;}

mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// r:    concatenated text (symbol codes), san symbols plus a 0 sentinel.
// sa:   suffix array, sa[i] = start of the i-th smallest suffix.
// rk:   inverse of sa for positions 1..n-1 of the suffix array.
// lcp:  lcp[i] = common prefix length of suffixes sa[i] and sa[i - 1].
// _t, _t2, c: scratch buffers for the radix sort in buildSa.
int r[N], sa[N], _t[N], _t2[N], c[N], rk[N], lcp[N], san;
// Next unused symbol code; doubles as the alphabet size passed to buildSa.
int maxc = 'z' + 1;

// Builds sa, rk and lcp for the n symbols r[0..n-1] by prefix doubling with
// counting sort, followed by a Kasai-style LCP pass.
//
// r: symbol codes, each in [0, m).
// n: number of symbols including the terminator.
// m: alphabet size (exclusive upper bound of the codes).
//
// NOTE(review): the out-of-range-looking reads y[sa[i] + step] and
// r[j + h] rely on r[n - 1] being a unique, smallest sentinel, which is how
// main() in this file calls it — confirm before reusing with other input.
void buildSa(int *r, int n, int m) {
    int *x = _t, *y = _t2;

    // Initial order: counting sort on the first symbol.
    for(int i = 0; i < m; i++) c[i] = 0;
    for(int i = 0; i < n; i++) c[x[i] = r[i]]++;
    for(int i = 1; i < m; i++) c[i] += c[i - 1];
    for(int i = n - 1; i >= 0; i--) sa[--c[x[i]]] = i;

    for(int step = 1; step <= n; step <<= 1) {
        // y lists suffixes ordered by their second half (offset `step`);
        // suffixes too short to have one come first.
        int p = 0;
        for(int i = n - step; i < n; i++) y[p++] = i;
        for(int i = 0; i < n; i++) if(sa[i] >= step) y[p++] = sa[i] - step;

        // Stable counting sort of that list by the first-half rank.
        for(int i = 0; i < m; i++) c[i] = 0;
        for(int i = 0; i < n; i++) c[x[y[i]]]++;
        for(int i = 1; i < m; i++) c[i] += c[i - 1];
        for(int i = n - 1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];

        // After the swap y holds the previous ranks and x receives new ones.
        swap(x, y);
        p = 1;
        x[sa[0]] = 0;
        for(int i = 1; i < n; i++) {
            if(y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + step] == y[sa[i] + step]) x[sa[i]] = p - 1;
            else x[sa[i]] = p++;
        }
        if(p >= n) break;   // every suffix already has a distinct rank
        m = p;
    }

    // rk[sa[0]] is intentionally left untouched, exactly as before.
    for(int i = 1; i < n; i++) rk[sa[i]] = i;

    // LCP pass: the match length drops by at most one when moving from
    // text position i to i + 1.
    int h = 0;
    for(int i = 0; i < n - 1; i++) {
        if(h) h--;
        int j = sa[rk[i] - 1];
        while(r[i + h] == r[j + h]) h++;
        lcp[rk[i]] = h;
    }
}

// Fenwick tree over indices 1..N-1 holding point counts.
struct Bit {
    int a[N];

    // Adds v at index x; x must be >= 1.
    void modify(int x, int v) {
        for(int i = x; i < N; i += i & -i) {
            a[i] += v;
        }
    }

    // Sum of indices 1..x.
    int sum(int x) {
        int ans = 0;
        for(int i = x; i; i -= i & -i) {
            ans += a[i];
        }
        return ans;
    }

    // Sum of indices L..R; 0 for an empty range.
    int query(int L, int R) {
        if(L > R) return 0;
        return sum(R) - sum(L - 1);
    }
} bit;

// Disjoint-set union over suffix-array positions; dl/dr are valid at a root
// and give the smallest/largest position in its component.
int fa[N], dl[N], dr[N];

// Returns the representative of x and points every node on the walked path
// directly at it. Iterative so that a long parent chain cannot exhaust the
// call stack.
int getRoot(int x) {
    int root = x;
    while(fa[root] != root) root = fa[root];
    while(fa[x] != root) {
        int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

// Unites the components of u and v; the root of u's component survives and
// absorbs the position range of v's component.
void Merge(int u, int v) {
    int x = getRoot(u);
    int y = getRoot(v);
    if(x == y) return;
    chkmin(dl[x], dl[y]);
    chkmax(dr[x], dr[y]);
    fa[y] = x;
}

// One half of a query attached to a suffix-array position: string range
// [vl, vr] and the query number, negated when the contribution is subtracted.
struct Qus {
    int vl, vr, id;
};

int n, q, belong[N], len[N], sid[N];   // belong: owning string per text position (0 for separators/sentinel);
                                       // len: string lengths; sid: start of each string in r
int ans[N];
char s[N];
vector<Qus> Q[N];                      // query halves bucketed by suffix-array position
int id[N], L[N], R[N], lcpId[N];       // id: strings by decreasing length; [L, R]: interval per string;
                                       // lcpId: suffix-array positions by decreasing lcp

// Debug helper: prints suffix number x of the suffix array, some padding,
// and its index, start, lcp and owning string.
void printSuf(int x) {
    for(int i = sa[x]; i < san; i++) putchar((char)r[i]);
    for(int i = 0; i < (sa[x] + 5); i++) putchar(' ');
    printf("i: %3d sa: %3d lcp: %3d belong: %3d ", x, sa[x], lcp[x], belong[sa[x]]);
}

// Reads the input, computes every answer offline and prints them separated
// by single spaces. Returns 1 if the input ends prematurely.
//
// NOTE(review): sizes are not validated; this assumes the total text length
// plus separators, n and q all stay below N, each token fits in s, and every
// k lies in [1, n] — confirm against the problem limits.
int main() {
    if(scanf("%d%d", &n, &q) != 2) return 1;

    // Build the concatenation; consecutive strings are split by a fresh
    // separator symbol that occurs nowhere else.
    for(int i = 1; i <= n; i++) {
        if(i > 1) r[san++] = maxc++;
        if(scanf("%s", s) != 1) return 1;
        len[i] = strlen(s);
        sid[i] = san;
        for(int j = 0; s[j]; j++) {
            r[san] = s[j];
            belong[san++] = i;
        }
        id[i] = i;
    }
    r[san] = 0;
    buildSa(r, san + 1, maxc);

    for(int i = 1; i <= san; i++) {
        fa[i] = i;
        dl[i] = dr[i] = i;
        lcpId[i] = i;
    }
    sort(lcpId + 1, lcpId + san + 1, [&](int x, int y) {
        return lcp[x] > lcp[y];
    });
    sort(id + 1, id + 1 + n, [&](int x, int y) {
        return len[x] > len[y];
    });

    // Longest strings first: once every boundary with lcp >= len has been
    // merged, the component of the string's own suffix is its interval.
    for(int i = 1, j = 1; i <= n; i++) {
        while(j <= san && lcp[lcpId[j]] >= len[id[i]]) {
            Merge(lcpId[j], lcpId[j] - 1);
            j++;
        }
        int Rt = getRoot(rk[sid[id[i]]]);
        L[id[i]] = dl[Rt];
        R[id[i]] = dr[Rt];
    }

    // Debug aid: uncomment to dump the suffix array.
//    puts("");
//    for(int i = 1; i <= san; i++) {
//        printSuf(i);
//    }

    // Split each query into "minus prefix L - 1" and "plus prefix R".
    // Bucket 0 is never visited by the sweep; it stands for the empty prefix.
    for(int i = 1; i <= q; i++) {
        int lo, hi, k;
        if(scanf("%d%d%d", &lo, &hi, &k) != 3) return 1;
        Q[L[k] - 1].push_back(Qus{lo, hi, -i});
        Q[R[k]].push_back(Qus{lo, hi, i});
    }

    // Sweep the suffix array; the Fenwick tree counts, per owning string,
    // the suffixes seen so far.
    for(int i = 1; i <= san; i++) {
        if(belong[sa[i]]) {
            bit.modify(belong[sa[i]], 1);
        }
        for(auto &part : Q[i]) {
            if(part.id > 0) ans[part.id] += bit.query(part.vl, part.vr);
            else ans[-part.id] -= bit.query(part.vl, part.vr);
        }
    }

    for(int i = 1; i <= q; i++) {
        printf("%d ", ans[i]);
    }
    return 0;
}