Description
Solution
如果直接暴力的话,可以枚举那个不同的字符在串一和串二里的位置分别是什么,然后算一下他们的(lcp)和(lcs)来更新答案,也就是(sum_{i = 1, j = 1} ^{i <= n, j <= m} lcp(i + 1, j + 1) + lcs(i - 1, j - 1) + 1)。
考虑如何优化复杂度,我们发现其实就是求最大的满足题意的(lcp)和(lcs)的值。那么可以进行启发式合并,按照(height)数组进行排序,每次将相邻的两个后缀所在的集合合并,因为是从大到小枚举所以根据(height)的性质,两个集合里的元素的(lcp)是可以确定的,取较小值就行。然后可以对跨越两个集合的后缀算它们最大的(lcs),发现就是在(rank)数组里查前驱和后继,每次更新答案就行。
还有可能出现完全相等的情况,直接扫一遍(height)数组就行。
合并的时候的细节还是挺多的,我开了三类(set),一类存每个集合里正串的(rank)值,一类存每个集合里满足条件的串一反串后缀的(rank),一类存串二反串后缀的(rank)。
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <set>
using namespace std;
#define IT set<int>::iterator
const int N = 100000;
int l, sa[N * 2 + 50], rk[N * 2 + 50], se[N * 2 + 50], tong[N * 2 + 50], cnt, st1[N * 2 + 50][21], rk2[N * 2 + 50], l1, l2, height[N * 2 + 50], lg[N * 2 + 50], ans, tmpans, id[N * 2 + 50], col[N * 2 + 50], tot, fa[N * 2 + 50], siz[N * 2 + 50], w[N * 2 + 50];
char st[N * 2 + 50], tmp[N + 50];
set<int> se1[N * 2 + 50], se2[N * 2 + 50], se3[N * 2 + 50];
void Rsort()
{
for (int i = 0; i <= cnt; i++) tong[i] = 0;
for (int i = 1; i <= l; i++) tong[rk[i]]++;
for (int i = 1; i <= cnt; i++) tong[i] += tong[i - 1];
for (int i = l; i >= 1; i--) sa[tong[rk[se[i]]]--] = se[i];
return;
}
void SA()
{
cnt = 256;
for (int i = 1; i <= l; i++) rk[i] = st[i], se[i] = i;
Rsort();
cnt = 0;
for (int k = 1; cnt < l; k <<= 1)
{
cnt = 0;
for (int i = l; i >= l - k + 1; i--) se[++cnt] = i;
for (int i = 1; i <= l; i++) if (sa[i] > k) se[++cnt] = sa[i] - k;
Rsort();
for (int i = 1; i <= l; i++) se[i] = rk[i];
rk[sa[1]] = cnt = 1;
for (int i = 2; i <= l; i++)
rk[sa[i]] = se[sa[i]] == se[sa[i - 1]] && se[sa[i] + k] == se[sa[i - 1] + k] ? cnt : ++cnt;
}
for (int i = 1; i <= l; i++)
{
height[rk[i]] = height[rk[i - 1]] - 1;
if (height[rk[i]] < 0) height[rk[i]] = 0;
while (st[i + height[rk[i]]] == st[sa[rk[i] - 1] + height[rk[i]]]) height[rk[i]]++;
}
return;
}
void MakeST(int stt[N * 2 + 50][21])
{
for (int i = 1; i <= l; i++) stt[i][0] = height[i];
for (int j = 1; j <= lg[l]; j++)
for (int i = 1; i + (1 << j) - 1 <= l; i++)
stt[i][j] = min(stt[i][j - 1], stt[i + (1 << (j - 1))][j - 1]);
return;
}
int Find(int x)
{
return fa[x] == x ? fa[x] : fa[x] = Find(fa[x]);
}
int Query(int stt[N * 2 + 50][21], int pos1, int pos2)
{
if (pos1 > pos2) swap(pos1, pos2);
pos1++;
int length = lg[pos2 - pos1 + 1];
return min(stt[pos1][length], stt[pos2 - (1 << length) + 1][length]);
}
int Cmp(int a, int b)
{
return height[a] > height[b];
}
void Merge(int posa, int posb)
{
int pa = Find(posa), pb = Find(posb), tmp = height[posb];
if (siz[pa] > siz[pb]) swap(pa, pb);
if (siz[pb] == 1 && siz[pa] == 1) w[pb] = tmp;
else
{
if (siz[pa] == 1) w[pb] = min(w[pb], tmp);
else if (siz[pb] == 1) w[pb] = min(w[pa], tmp);
else w[pb] = min(w[pb], min(tmp, w[pa]));
}
if (col[sa[pa]] != col[sa[pb]]) tmpans = max(tmpans, w[pb]);
for (IT it = se3[pa].begin(); it != se3[pa].end(); it++)
{
int pos = sa[*it];
if (pos >= 3 && pos <= l1)
{
IT last = se2[pb].lower_bound(rk2[l - pos + 3]);
if (last != se2[pb].begin()) { last--; int lst = *last; tmpans = max(tmpans, w[pb] + Query(st1, rk2[l - pos + 3], lst)); }
IT nxxt = se2[pb].upper_bound(rk2[l - pos + 3]);
if (nxxt != se2[pb].end()) { int nxt = *nxxt; tmpans = max(tmpans, w[pb] + Query(st1, rk2[l - pos + 3], nxt)); }
se1[pb].insert(rk2[l - pos + 3]);
}
if (pos >= l1 + 4 && pos <= l)
{
IT last = se1[pb].lower_bound(rk2[l - pos + 3]);
if (last != se1[pb].begin()) { last--; int lst = *last; tmpans = max(tmpans, w[pb] + Query(st1, rk2[l - pos + 3], lst)); }
IT nxxt = se1[pb].upper_bound(rk2[l - pos + 3]);
if (nxxt != se1[pb].end()) { int nxt = *nxxt; tmpans = max(tmpans, w[pb] + Query(st1, rk2[l - pos + 3], nxt)); }
se2[pb].insert(rk2[l - pos + 3]);
}
se3[pb].insert(*it);
}
siz[pb] += siz[pa]; fa[pa] = pb;
return;
}
int main()
{
scanf("%s", st + 1);
l1 = strlen(st + 1);
scanf("%s", tmp + 1);
l2 = strlen(tmp + 1);
st[l1 + 1] = '#';
for (int i = 1; i <= l2; i++) st[i + l1 + 1] = tmp[i];
l = l1 + l2 + 1;
lg[0] = -1;
for (int i = 1; i <= l / 2; i++) swap(st[i], st[l - i + 1]);
SA();
for (int i = 1; i <= l; i++) lg[i] = lg[i >> 1] + 1, rk2[i] = rk[i];
MakeST(st1);
for (int i = 1; i <= l; i++) sa[i] = height[i] = 0;
for (int i = 1; i <= l / 2; i++) swap(st[i], st[l - i + 1]);
SA();
for (int i = 1; i <= l1; i++) col[i] = 1;
for (int i = l1 + 1; i <= l; i++) col[i] = 2;
for (int i = 2; i <= l; i++) if (col[sa[i]] != col[sa[i - 1]]) ans = max(ans, height[i]);
for (int i = 2; i <= l; i++)
{
if (sa[i] <= l1 && sa[i] >= 3) se1[i].insert(rk2[l - sa[i] + 3]);
if (sa[i] >= l1 + 4 && sa[i] <= l) se2[i].insert(rk2[l - sa[i] + 3]);
se3[i].insert(i);
id[++tot] = i, siz[i] = 1, fa[i] = i;
}
sort(id + 1, id + tot + 1, Cmp);
for (int i = 1; i <= tot; i++) Merge(id[i] - 1, id[i]);
ans = max(ans, tmpans + 1);
printf("%d", ans);
return 0;
}