[BZOJ3238][Ahoi2013]差异
试题描述
输入
一行,一个字符串S
输出
一行,一个整数,表示所求值
输入示例
cacao
输出示例
54
数据规模及约定
2<=N<=500000,S由小写英文字母组成
题解
我们考虑把两部分分开计算,“∑ len(Ti) + len(Tj)” 显然可以推出公式,现在就是算“∑ lcp(Ti, Tj)”的问题了。
我们可以先把 height 数组求出来,那么 ∑ lcp(Ti, Tj) = ∑l ∑r min{ heightl, heightl+1, ... , heightr };
接下来我们可以令 sufl = ∑r min{ heightl, heightl+1, ... , heightr },于是这个 sufl 就可以递推转移得到了,我们只需要处理出每个位置 i 后边最靠前的位置 p 满足 height[p] < height[i],那么 sufi = sufp + height[i] * (p - i)(想一想,为什么)。
那么最终 ∑ lcp(Ti, Tj) = ∑ sufi,问题解决。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> #define maxn 500010 #define maxlog 19 #define LL long long char S[maxn]; int n, rank[maxn], height[maxn], sa[maxn], Ws[maxn]; bool cmp(int* a, int p1, int p2, int l) { if(p1 + l > n && p2 + l > n) return a[p1] == a[p2]; if(p1 + l > n || p2 + l > n) return 0; return a[p1] == a[p2] && a[p1+l] == a[p2+l]; } void ssort() { int *x = rank, *y = height; int m = 0; for(int i = 1; i <= n; i++) Ws[x[i] = S[i]]++, m = std::max(m, x[i]); for(int i = 1; i <= m; i++) Ws[i] += Ws[i-1]; for(int i = n; i; i--) sa[Ws[x[i]]--] = i; for(int j = 1, pos = 0; pos < n; j <<= 1, m = pos) { pos = 0; for(int i = n - j + 1; i <= n; i++) y[++pos] = i; for(int i = 1; i <= n; i++) if(sa[i] > j) y[++pos] = sa[i] - j; for(int i = 1; i <= m; i++) Ws[i] = 0; for(int i = 1; i <= n; i++) Ws[x[i]]++; for(int i = 1; i <= m; i++) Ws[i] += Ws[i-1]; for(int i = n; i; i--) sa[Ws[x[y[i]]]--] = y[i]; std::swap(x, y); pos = 1; x[sa[1]] = 1; for(int i = 2; i <= n; i++) x[sa[i]] = cmp(y, sa[i], sa[i-1], j) ? pos : ++pos; } return ; } void calch() { for(int i = 1; i <= n; i++) rank[sa[i]] = i; for(int i = 1, j, k = 0; i <= n; height[rank[i++]] = k) for(k ? k-- : 0, j = sa[rank[i]-1]; S[j+k] == S[i+k]; k++); return ; } int mnh[maxlog][maxn], Log[maxn]; void rmq_init() { Log[1] = 0; for(int i = 2; i <= n; i++) Log[i] = Log[i>>1] + 1; for(int i = 1; i <= n; i++) mnh[0][i] = height[i]; for(int j = 1; (1 << j) <= n; j++) for(int i = 1; i + (1 << j) - 1 <= n; i++) mnh[j][i] = std::min(mnh[j-1][i], mnh[j-1][i+(1<<j-1)]); return ; } int qhei(int l, int r) { if(r > n) return -1; int t = Log[r-l+1]; return std::min(mnh[t][l], mnh[t][r-(1<<t)+1]); } LL f[maxn], ans; int main() { scanf("%s", S + 1); n = strlen(S + 1); ssort(); calch(); // for(int i = 1; i <= n; i++) printf("%d%c", height[i], i < n ? ' ' : ' '); rmq_init(); for(int i = n; i; i--) { int l = i, r = n + 2; while(r - l > 1) { int mid = l + r >> 1; // printf("[%d, %d]: %d ", i, mid, qhei(i, mid)); if(qhei(i, mid) >= height[i]) l = mid; else r = mid; } if(l <= n) l++; // printf("%d ", l); f[i] = f[l] + (LL)(l - i) * height[i]; ans += f[i]; } ans <<= 1; // LL all = ((LL)n * (n + 1) >> 1) * (n - 1); printf("%lld ", ((LL)n * (n + 1) >> 1) * (n - 1) - ans); return 0; }
上面代码中我求每个位置的后继时用的二分,下面有一个(看上去?)暴力的求法,交到大视野上发现更快了,不知道这个暴力做法靠不靠谱。。。
我们令位置 i 的后继为 nxt[i],那么暴力的做法就是检查位置 height[i+1] < height[i] 是否成立,如果成立那么显然 nxt[i] = i + 1,否则沿着 nxt 数组往下走一直碰到 height 值更小的位置。
填坑:这个做法其实就是单调栈。。。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> #define maxn 500010 #define maxlog 19 #define LL long long char S[maxn]; int n, rank[maxn], height[maxn], sa[maxn], Ws[maxn]; bool cmp(int* a, int p1, int p2, int l) { if(p1 + l > n && p2 + l > n) return a[p1] == a[p2]; if(p1 + l > n || p2 + l > n) return 0; return a[p1] == a[p2] && a[p1+l] == a[p2+l]; } void ssort() { int *x = rank, *y = height; int m = 0; for(int i = 1; i <= n; i++) Ws[x[i] = S[i]]++, m = std::max(m, x[i]); for(int i = 1; i <= m; i++) Ws[i] += Ws[i-1]; for(int i = n; i; i--) sa[Ws[x[i]]--] = i; for(int j = 1, pos = 0; pos < n; j <<= 1, m = pos) { pos = 0; for(int i = n - j + 1; i <= n; i++) y[++pos] = i; for(int i = 1; i <= n; i++) if(sa[i] > j) y[++pos] = sa[i] - j; for(int i = 1; i <= m; i++) Ws[i] = 0; for(int i = 1; i <= n; i++) Ws[x[i]]++; for(int i = 1; i <= m; i++) Ws[i] += Ws[i-1]; for(int i = n; i; i--) sa[Ws[x[y[i]]]--] = y[i]; std::swap(x, y); pos = 1; x[sa[1]] = 1; for(int i = 2; i <= n; i++) x[sa[i]] = cmp(y, sa[i], sa[i-1], j) ? pos : ++pos; } return ; } void calch() { for(int i = 1; i <= n; i++) rank[sa[i]] = i; for(int i = 1, j, k = 0; i <= n; height[rank[i++]] = k) for(k ? k-- : 0, j = sa[rank[i]-1]; S[j+k] == S[i+k]; k++); return ; } int nxt[maxn]; LL f[maxn], ans; int main() { scanf("%s", S + 1); n = strlen(S + 1); ssort(); calch(); // for(int i = 1; i <= n; i++) printf("%d%c", height[i], i < n ? ' ' : ' '); nxt[n+1] = n + 1; height[n+1] = -1; for(int i = n; i; i--) { if(height[i+1] < height[i]) nxt[i] = i + 1; else { nxt[i] = nxt[i+1]; while(height[nxt[i]] >= height[i]) nxt[i] = nxt[nxt[i]]; } f[i] = f[nxt[i]] + (LL)(nxt[i] - i) * height[i]; ans += f[i]; } ans <<= 1; // LL all = ((LL)n * (n + 1) >> 1) * (n - 1); printf("%lld ", ((LL)n * (n + 1) >> 1) * (n - 1) - ans); return 0; }
求 hack 啊。。。。。