http://uoj.ac/problem/131
求出后缀数组和 height 数组后,按相似度 r 从大到小扫描并合并:height 值恰好为 r 的相邻两个后缀从这一刻起变为"r 相似",因此每次合并都相当于把后缀数组中两个紧挨着的区间合成一个。
合并区间可以用并查集来实现,每个区间的信息都记录在这个区间的并查集的根上,合并并查集时用一个根的信息更新另一个根的信息同时计算两个答案。
时间复杂度 $O(n\log n)$。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;

// Array capacity; presumably sized for the problem's n limit plus a few
// sentinel slots (TODO confirm against the problem statement).
const int N = 300003;

// Running answer #1: number of pairs accumulated by merge().
ll ret1 = 0;
// Running answer #2: best product seen so far. This huge negative value means
// "no pair yet"; main() maps it to 0 when printing.
ll ret2 = -1000000000000000001ll;

// Rank buffers for mkhz(). They are twice the capacity because mkhz() reads
// y[sa[i] + j] with j < n, i.e. up to roughly 2n.
int t1[2 * N];
int t2[2 * N];
// Bucket counters used by st().
int c[N];
// One counting-sort (radix) pass of the suffix-array construction.
//   x  - key of each position; keys must lie in [0, m]
//   y  - the positions to sort, in their current (second-key) order
//   sa - output: the positions of y ordered by key x
// Walking y backwards while decrementing the bucket counters keeps equal keys
// in the relative order they had in y, so the pass is stable.
void st(int *x, int *y, int *sa, int n, int m) {
    for (int key = 0; key <= m; ++key) c[key] = 0;
    for (int pos = 0; pos < n; ++pos) c[x[pos]]++;
    // Turn the counts into exclusive end offsets of each bucket.
    for (int key = 1; key <= m; ++key) c[key] += c[key - 1];
    for (int k = n - 1; k >= 0; --k) {
        const int pos = y[k];
        sa[--c[x[pos]]] = pos;
    }
}
void mkhz(int *r, int *sa, int n, int m) {
int *x = t1, *y = t2, *t, i, j, p;
for (i = 0; i < n; ++i) x[i] = r[i], y[i] = i;
st(x, y, sa, n, m);
for (j = 1, p = 1; j < n && p < n; j <<= 1, m = p - 1) {
for (p = 0, i = n - j; i < n; ++i) y[p++] = i;
for (i = 0; i < n; ++i) if (sa[i] >= j) y[p++] = sa[i] - j;
st(x, y, sa, n, m);
for (t = x, x = y, y = t, x[sa[0]] = 0, p = 1, i = 1; i < n; ++i)
x[sa[i]] = y[sa[i]] == y[sa[i - 1]] && y[sa[i] + j] == y[sa[i - 1] + j] ? p - 1 : p++;
}
}
// Kasai's algorithm: fills rank[] as the inverse of sa[], then
// h[rank[i]] = LCP(suffix i, suffix sa[rank[i] - 1]) for i in [1, n).
// h[rank[0]] is not written. Each suffix i >= 1 is assumed to have
// rank[i] >= 1, e.g. because suffix 0 starts with a unique smallest symbol.
// The scan stops at the first mismatch, so r must contain a mismatching
// value (such as a trailing 0) before any read goes out of bounds.
void mkh(int *r, int *sa, int *rank, int *h, int n) {
    for (int pos = 0; pos < n; ++pos) rank[sa[pos]] = pos;
    int lcp = 0;
    for (int i = 1; i < n; ++i) {
        // Dropping the first symbol shortens the previous LCP by at most one.
        if (lcp > 0) --lcp;
        const int prev = sa[rank[i] - 1];
        while (r[i + lcp] == r[prev + lcp]) ++lcp;
        h[rank[i]] = lcp;
    }
}
// Input string, read 1-based into s[1..n].
char s[N];
// Symbols handed to the suffix-array routines. main() fills r[1..n] only,
// so r[0] keeps its zero value.
int r[N], n;
// Suffix-array data. NOTE: because of `using namespace std`, an unqualified
// `rank` can be ambiguous with std::rank (<type_traits>, C++11); qualify it
// as ::rank where it is used outside a shadowing scope.
int rank[N], sa[N];
// a[k]: value attached to the suffix of rank k; h[k]: LCP of ranks k-1 and k;
// id: rank indices reordered by h in main().
int a[N], h[N], id[N];
// Union-find state, meaningful on each set's root: parent, max, min, size.
int fa[N], num_max[N], num_min[N], sz[N];
// Final answers indexed by similarity.
ll ans1[N], ans2[N];
bool cmp(int x, int y) {return h[x] > h[y];}
// Union-find root lookup with full path compression.
// Written iteratively on purpose: merge() always links root x under root y
// without balancing, so a parent chain can grow to O(n) nodes. For n around
// 3e5, a recursive find could then overflow the call stack. The final parent
// links are the same as with the recursive version: every node on the path
// ends up pointing straight at the root.
int find(int x) {
    int root = x;
    while (fa[root] != root) root = fa[root];
    while (fa[x] != root) {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}
// Raises x to y when y is strictly greater; otherwise leaves x untouched.
template <typename T>
void check_max(T &x, T y) {
    if (!(y > x)) return;
    x = y;
}
// Lowers x to y when y is strictly smaller; otherwise leaves x untouched.
template <typename T>
void check_min(T &x, T y) {
    if (!(y < x)) return;
    x = y;
}
void merge(int x, int y) {
x = find(x); y = find(y);
ret1 += 1ll * sz[x] * sz[y];
check_max(ret2, max(1ll * num_max[x] * num_max[y], 1ll * num_min[x] * num_min[y]));
fa[x] = y; sz[y] += sz[x];
check_max(num_max[y], num_max[x]);
check_min(num_min[y], num_min[x]);
}
int main() {
scanf("%d", &n);
scanf("%s", s + 1);
for (int i = 1; i <= n; ++i)
r[i] = s[i];
mkhz(r, sa, n + 1, 300);
mkh(r, sa, rank, h, n + 1);
int ai;
for (int i = 1; i <= n; ++i) scanf("%d", &ai), a[rank[i]] = ai, id[i] = i;
stable_sort(id + 2, id + n + 1, cmp);
for (int i = 1; i <= n; ++i) fa[i] = i, sz[i] = 1, num_max[i] = num_min[i] = a[i];
int tmp = 2, start;
for (int i = n - 1; i >= 0; --i) {
for (start = tmp; h[id[tmp]] >= i && tmp <= n; ++tmp);
for (int j = start; j < tmp; ++j)
merge(id[j] - 1, id[j]);
ans1[i] = ret1; ans2[i] = ret2 != -1000000000000000001ll ? ret2 : 0;
}
for (int i = 0; i < n; ++i) printf("%lld %lld
", ans1[i], ans2[i]);
return 0;
}