*题目描述:
字符串是oi界常考的问题。现在给定你n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串(注意包括本身)。
*输入:
第一行两个整数n,k。接下来n行每行一个字符串。
*输出:
输出一行n个整数,第i个整数表示第i个字符串的答案。
*样例输入:
3 1
abc
a
ab
*样例输出:
6 1 3
*提示:
对于100%的数据,n,k,l<=100000
*来源:
后缀数组
*题解:
广义后缀自动机。建完广义后缀自动机后,统计一下某个节点在所有字符串中出现的次数,对于次数大于等于k的节点统计一下答案。
*代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#ifdef WIN32
#define LL "%I64d"
#else
#define LL "%lld"
#endif
#ifdef CT
#define debug(...) printf(__VA_ARGS__)
#define setfile()
#else
#define debug(...)
#define filename ""
#define setfile() freopen(filename".in", "r", stdin); freopen(filename".out", "w", stdout);
#endif
#define R register
#define getc() (S == T && (T = (S = B) + fread(B, 1, 1 << 15, stdin), S == T) ? EOF : *S++)
#define dmax(_a, _b) ((_a) > (_b) ? (_a) : (_b))
#define dmin(_a, _b) ((_a) < (_b) ? (_a) : (_b))
#define cmax(_a, _b) (_a < (_b) ? _a = (_b) : 0)
#define cmin(_a, _b) (_a > (_b) ? _a = (_b) : 0)
char B[1 << 15], *S = B, *T = B;
inline int FastIn()
{
R char ch; R int cnt = 0; R bool minus = 0;
while (ch = getc(), (ch < '0' || ch > '9') && ch != '-') ;
ch == '-' ? minus = 1 : cnt = ch - '0';
while (ch = getc(), ch >= '0' && ch <= '9') cnt = cnt * 10 + ch - '0';
return minus ? -cnt : cnt;
}
#define maxn 100010
struct sam
{
sam *next[26], *fa;
int val, last_vis, c;
bool vis;
long long sum;
}mem[maxn << 1], *tot = mem;
inline sam *extend(R sam *p, R int c)
{
if (p -> next[c])
{
R sam *q = p -> next[c];
if (q -> val == p -> val + 1)
return q;
else
{
R sam *nq = ++tot;
memcpy(nq -> next, q -> next, sizeof nq -> next);
nq -> val = p -> val + 1;
nq -> fa = q -> fa;
q -> fa = nq;
for ( ; p && p -> next[c] == q; p = p -> fa)
p -> next[c] = nq;
return nq;
}
}
R sam *np = ++tot;
np -> val = p -> val + 1;
for ( ; p && !p -> next[c]; p = p -> fa) p -> next[c] = np;
if (!p)
np -> fa = mem;
else
{
R sam *q = p -> next[c];
if (q -> val == p -> val + 1)
np -> fa = q;
else
{
R sam *nq = ++tot;
memcpy(nq -> next, q -> next, sizeof nq -> next);
nq -> val = p -> val + 1;
nq -> fa = q -> fa;
q -> fa = np -> fa = nq;
for ( ; p && p -> next[c] == q; p = p -> fa)
p -> next[c] = nq;
}
}
return np;
}
void get_ans(R sam *x)
{
if (x == mem || x -> vis) return;
x -> vis = 1; get_ans(x -> fa); x -> sum += x -> fa -> sum;
}
char str[maxn], tot_str[maxn];
int left[maxn], right[maxn];
int main()
{
// setfile();
R int n, k;
scanf("%d%d", &n, &k);
R int tot_len = 0;
for (R int i = 1; i <= n; ++i)
{
scanf("%s", str);
R sam* x = mem;
R int len = strlen(str);
left[i] = tot_len;
right[i] = tot_len = len + tot_len - 1; ++tot_len;
memcpy(tot_str + left[i], str, len * sizeof(char));
for (R int j = 0; j < len; ++j)
x = extend(x, str[j] - 'a');
}
for (R int i = 1; i <= n; ++i)
{
R sam *x = mem, *t;
for (R int j = left[i]; j <= right[i]; ++j)
{
x = x -> next[tot_str[j] - 'a'];
for (t = x; t && t -> last_vis != i; t = t -> fa)
t -> last_vis = i, t -> c++;
}
}
for (R sam *iter = mem + 1; iter <= tot; ++iter)
iter -> sum = iter -> c >= k ? iter -> val - iter -> fa -> val : 0;
for (R sam *iter = mem + 1; iter <= tot; ++iter)
get_ans(iter);
for (R int i = 1; i <= n; ++i)
{
R sam *x = mem; R long long ans = 0;
for (R int j = left[i]; j <= right[i]; ++j)
x = x -> next[tot_str[j] - 'a'], ans += x -> sum;
printf("%lld ", ans );
}
return 0;
}