这里写一篇后缀数组的题解。
首先我们考虑一个贪心,我们不断划分,则最后一个区间的最优情况一定是只有一个元素的,同理那么我们的第(i)个区间的字符个数一定比第(i+1)个区间的字符个数多(1)。
我们先考虑一个(Theta(n^2))的(dp),定义(dp_i)表示将([i,i+dp_i-1])划分一个区间,则我们可以枚举一个(j),满足(j>i)。
令(k = min(j - i , dp_j+1))同时( exttt{hash}(i , i+k-2) = exttt{hash}(j , j + k-2))或( exttt{hash}(i+1 , i+k-1)= exttt{hash}(j , j + k-2))
则(dp_i=max(dp_i,k))
我们成功实现了(n^2)的做法。
现在我们考虑打表找规律,发现一个有用的性质:
(dp_i leq dp_{i+1}+1)
因此我们可以考虑记录一个(cur),表示我们在对(i)转移时,(i)所能够达到的最大的位置。现在考虑如果(i)能够达到这个值需要满足哪些条件,即存在一个(j > cur),同时(cur-i leq dp_j),同时( exttt{hash}(i , cur-1)= exttt{hash}(j,j+cur-i-1))或者( exttt{hash}(i+1 , cur)= exttt{hash}(j,j+cur-i-1))。如果无法达到,那么我们可以退而求其次,(cur-1),进行下一轮判断。
我们考虑求一个后缀数组,那么我们就可以用二分在(height)数组上找到一个区间,使得这个区间的每一个元素与(i)的(LCP)都大于等于(cur-i),现在我们想要知道这个区间内是否存在一个元素(j>cur)的(dp)值大于等于(cur-i),所以我们可以用线段树维护一下最大值,每次(cur-1)时,把(dp)值加入即可。
下面是代码:
#include <cmath>
#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN = 5e5 + 5;
int sa[MAXN] , rk[MAXN] , x[MAXN] , y[MAXN] , t[MAXN] , cnt[MAXN] , n , height[MAXN];
char s[MAXN];
void make_suffix() {
int m = 256;
for (int i = 1; i <= n; ++i) cnt[x[i] = s[i]] ++;
for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
for (int i = 1; i <= n; ++i) sa[cnt[x[i]] --] = i;
for (int k = 1; k <= n; k <<= 1) {
int tot = 0;
for (int i = n - k + 1; i <= n; ++i) y[++tot] = i;
for (int i = 1; i <= n; ++i) if(sa[i] > k) y[++tot] = sa[i] - k;
for (int i = 1; i <= m; ++i) cnt[i] = 0;
for (int i = 1; i <= n; ++i) cnt[x[y[i]]] ++;
for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; --i) sa[cnt[x[y[i]]] --] = y[i];
tot = 1;
t[sa[1]] = 1;
for (int i = 2; i <= n; ++i) {
if(x[sa[i]] == x[sa[i - 1]] && x[sa[i] + k] == x[sa[i - 1] + k]) t[sa[i]] = tot;
else t[sa[i]] = ++tot;
}
if(tot >= n) break;
for (int i = 1; i <= n; ++i) x[i] = t[i];
m = tot + 1;
}
for (int i = 1; i <= n; ++i) rk[sa[i]] = i;
}
void make_height() {
int k = 0;
for (int i = 1; i <= n; ++i) {
if(k) k --;
int j = sa[rk[i] - 1];
while(s[i + k] == s[j + k]) k ++;
height[rk[i]] = k;
}
}
int st[MAXN][22] , dp[MAXN] , tr[MAXN << 2];
void update(int l , int r , int now , int x , int y) {
if(l == r) {
tr[now] = y;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) update(l , mid , now << 1 , x , y);
else update(mid + 1 , r , now << 1 | 1 , x , y);
tr[now] = max(tr[now << 1] , tr[now << 1 | 1]);
}
int get_min(int l , int r) {
if(l > r) return n;
int x = log(r - l + 1) / log(2);
return min(st[l][x] , st[r - (1 << x) + 1][x]);
}
int get_l(int x , int y) {
int l = 1 , r = x , ans = 0;
while(l <= r) {
int mid = (l + r) >> 1;
if(get_min(mid + 1 , x) >= y) r = mid - 1 , ans = mid;
else l = mid + 1;
}
return ans;
}
int get_r(int x , int y) {
int l = x , r = n , ans = 0;
while(l <= r) {
int mid = (l + r) >> 1;
if(get_min(x + 1 , mid) >= y) l = mid + 1 , ans = mid;
else r = mid - 1;
}
return ans;
}
int find(int l , int r , int now , int x , int y) {
if(l >= x && r <= y) {
return tr[now];
}
int mid = (l + r) >> 1 , res = 0;
if(x <= mid) res = max(res , find(l , mid , now << 1 , x , y));
if(y > mid) res = max(res , find(mid + 1 , r , now << 1 | 1 , x , y));
return res;
}
bool check(int x , int y) {
int l = get_l(x , y) , r = get_r(x , y);
if(find(1 , n , 1 , l , r) >= y) return false;
return true;
}
int main() {
scanf("%d" , &n);
if(n == 1) {
printf("1
");
return 0;
}
scanf("%s" , s + 1);
make_suffix();make_height();
for (int i = 1; i <= n; ++i) st[i][0] = height[i];
for (int j = 1; (1 << j) <= n; ++j) {
for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
st[i][j] = min(st[i][j - 1] , st[i + (1 << (j - 1))][j - 1]);
}
}
int cur = n;dp[n] = 1;
// update(1 , n , 1 , rk[n] , dp[n]);
int ans = 0;
for (int i = n - 1; i >= 1; --i) {
while(check(rk[i] , cur - i) && check(rk[i + 1] , cur - i) && i < cur) cur -- , update(1 , n , 1 , rk[cur + 1] , dp[cur + 1]);
dp[i] = cur - i + 1;ans = max(ans , dp[i]);
}
printf("%d" , ans);
return 0;
}