这题过得很艰辛啊。最早在去年网络赛时就见到这题了,当时是不会。前段时间好好学了一下后缀数组,昨天再来看这题,就知道显然是后缀数组的应用了。
我知道解这道题肯定是类似于用后缀数组求字符串所有不同子串个数的思路。但一开始想走捷径,想法是先把所有的b连起来,再连上a,并通过在a后面加上一些特定字符的方法,使在sa中a的子串都排在b的后面,这样只要倒着统计一遍就行了。代码打完后才想明白这是完全行不通的。正确思路只能是分两次做,一次是只把所有的b连起来,求出所有不同子串数sumb,第二次是把a也连上,求出所有不同子串数sumab,最后结果就是sumab - sumb。
具体实现过程也有好多需要注意的,比如,为了方便除掉包含分隔字符的子串数,需要使每个分隔字符都不一样,由于所有的b串最多有100000个,所以没法用字符类型处理了,我是用的整型数组代替。还有就是最后结果要用long long类型,等等。
/* * hdu4416/win.cpp * Created on: 2013-5-21 * Author : ben */ #include <cstdio> #include <cstdlib> #include <cstring> #include <cmath> #include <ctime> #include <iostream> #include <algorithm> #include <queue> #include <set> #include <map> #include <stack> #include <string> #include <vector> #include <deque> #include <list> #include <functional> #include <numeric> #include <cctype> using namespace std; typedef long long LL; const int MAXN = 330000; int *s; int sa[MAXN], height[MAXN], rank[MAXN]; int tmp[MAXN], top[MAXN]; int N; void makesa() { int i, j, len, na; na = (N < 256 ? 256 : N); memset(top, 0, na * sizeof(int)); for (i = 0; i < N; i++) top[rank[i] = s[i] & 0xff]++; for (i = 1; i < na; i++) top[i] += top[i - 1]; for (i = 0; i < N; i++) sa[--top[rank[i]]] = i; for (len = 1; len < N; len <<= 1) { for (i = 0; i < N; i++) { j = sa[i] - len; if (j < 0) j += N; tmp[top[rank[j]]++] = j; } sa[tmp[top[0] = 0]] = j = 0; for (i = 1; i < N; i++) { if (rank[tmp[i]] != rank[tmp[i - 1]] || rank[tmp[i] + len] != rank[tmp[i - 1] + len]) top[++j] = i; sa[tmp[i]] = j; } memcpy(rank, sa, N * sizeof(int)); memcpy(sa, tmp, N * sizeof(int)); if (j >= N - 1) break; } } void lcp() { int i, j, k; for (j = rank[height[i = k = 0] = 0]; i < N - 1; i++, k++) while (k >= 0 && s[i] != s[sa[j - 1] + k]) height[j] = (k--), j = rank[sa[j] + 1]; } int get_str(int *str) { char c; while ((c = getchar()) <= ' ') { if(c == EOF) { return -1; } } int I = 0; while (c > ' ') { str[I++] = c; c = getchar(); } str[I] = 0; return I; } int str[MAXN]; int lenarr[MAXN]; int main() { #ifndef ONLINE_JUDGE freopen("data.in", "r", stdin); #endif int T, n; int sep; LL sumab, sumb, temp; scanf("%d", &T); for(int t = 1; t <= T; t++) { memset(str, 0, sizeof(str)); //分隔符从-1开始,逐次减小 sep = -1; scanf("%d", &n); //输入a串 lenarr[0] = get_str(str); str[lenarr[0]] = sep--; //输入b串并全部连接起来,用sep分隔 int *p = str + lenarr[0] + 1; for(int i = 1; i <= n; i++) { lenarr[i] = get_str(p); p = p + lenarr[i]; *(p++) = sep--; } N = p - str; str[N - 1] = 0; //a、b连起来求后缀数组 s = str; makesa(); lcp(); //统计不同子串个数 sumab = 0; for(int i = 1; i < N; i++) { sumab += N - 1 - sa[i] - height[i]; } //减掉带#的子串数 temp = N - 1; for(int i = 0; i < n; i++) { temp -= lenarr[i]; sumab -= temp * (lenarr[i] + 1); temp--; } //单独求b的后缀数组 s = str + lenarr[0] + 1; N = N - lenarr[0] - 1; makesa(); lcp(); //统计不同子串个数 sumb = 0; for(int i = 1; i < N; i++) { sumb += N - 1 - sa[i] - height[i]; } //减掉带#的子串数 temp = N - 1; for(int i = 1; i < n; i++) { temp -= lenarr[i]; sumb -= temp * (lenarr[i] + 1); temp--; } printf("Case %d: %I64d\n", t, sumab - sumb); } return 0; }