首先可以想到一个暴力的(mathcal{O(n^3)})算法:枚举( ext{A}),( ext{B})的两个后缀,算出他们的最长公共前缀。
这样显然是对的,但是也显然可以用后缀数组优化。
把( ext{A}),( ext{B})两个串用一个没出现过的字符隔开然后连起来,对新串求后缀数组。那么对于原来的两个后缀,也可以表现为在这个串里对应位置后缀的LCP,也就是区间 height 的 min.
那么区间 height 的 min 就启发我们用单调栈来枚举这个 min 然后算贡献,对于这个min,左端点的选取是一个区间,右端点也是,故统计这两段内有多少对后缀在原串中一个在A,一个在B即可。
这个东西某个 sb 一开始写了主席树求,后来发现直接前缀和一下就可以了。
所以复杂度是(mathcal{O(n log_2 n)}),瓶颈在求sa。
具体为啥我调了一天,大概是因为sa求错了一直wa10...
#include <bits/stdc++.h>
#pragma GCC optimize("Ofast","-funroll-loops","-fdelete-null-pointer-checks")
#pragma GCC target("ssse3","sse3","sse2","sse","avx2","avx")
#define rep(i, l, r) for (int i = (l); i <= (r); ++i)
#define per(i, r, l) for (int i = (r); i >= (l); --i)
using namespace std;
typedef long long ll;
typedef pair <int, int> pii;
typedef vector <int> vi;
int gi() {
int f = 1, x = 0; char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') f = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
return f * x;
}
const int N = 400005;
char s1[N >> 1], s2[N >> 1], s[N];
int n, sa[N], rk[N], id[N], px[N], cnt[N], rk_[N], h[N], l[N], r[N], st[N], tp;
bool cmp(int x, int y, int w) {
return rk_[x] == rk_[y] && rk_[x + w] == rk_[y + w];
}
void get_SA() {
int i, j, m = 300, p, w;
for (i = 1; i <= n; ++i) ++cnt[rk[i] = s[i]];
for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
for (i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i;
for (w = 1; w < n; w <<= 1, m = p) {
for (p = 0, i = n; i > n - w; --i) id[++p] = i;
for (i = 1; i <= n; ++i) if (sa[i] > w) id[++p] = sa[i] - w;
memset (cnt, 0, sizeof(cnt));
for (i = 1; i <= n; ++i) ++cnt[px[i] = rk[id[i]]];
for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
for (i = n; i >= 1; --i) sa[cnt[px[i]]--] = id[i];
memcpy(rk_, rk, sizeof(rk));
for (p = 0, i = 1; i <= n; ++i)
rk[sa[i]] = (cmp(sa[i - 1], sa[i], w) ? p : ++p);
}
for (i = 1, j = 0; i <= n; ++i) {
if (rk[i] == 1) continue;
int now = sa[rk[i] - 1];
while (s[i + j] == s[now + j]) ++j;
h[rk[i]] = j;
if (j) --j;
}
}
int S1[N], S2[N];
int main() {
scanf("%s%s",s1 + 1, s2 + 1);
int len = strlen(s1 + 1), len_ = strlen(s2 + 1);
n = len + len_ + 1;
rep (i, 1, n) {
if (i <= len) s[i] = s1[i];
else if (i == len + 1) s[i] = '$';
else s[i] = s2[i - len - 1];
}
get_SA();
rep (i, 1, n) l[i] = 0, r[i] = n + 1;
rep (i, 1, n) {
while (tp && h[st[tp]] >= h[i]) --tp;
if (tp) l[i] = st[tp];
st[++tp] = i;
}
tp = 0;
per (i, n, 1) {
while (tp && h[st[tp]] > h[i]) --tp;
if (tp) r[i] = st[tp];
st[++tp] = i;
}
ll ans = 0;
rep (i, 1, n) {
S1[i] = S1[i - 1] + (1 <= sa[i] && sa[i] <= len);
S2[i] = S2[i - 1] + (len + 2 <= sa[i] && sa[i] <= n);
}
rep (i, 1, n) {
int l_ = l[i] + 1, r_ = r[i] - 1;
ans += 1ll * h[i] * (S1[i - 1] - S1[l_ - 2]) * (S2[r_] - S2[i - 1]);
ans += 1ll * h[i] * (S2[i - 1] - S2[l_ - 2]) * (S1[r_] - S1[i - 1]);
}
cout << ans << '
';
return 0;
}