题面
Description
给定一个字符串 ss 。现在问你有多少个本质不同的 ss 的子串 t=t1t2⋯tm(m>0)t=t1t2⋯tm(m>0) 使得将 tt 循环左移一位后变成的 t′=t2⋯tmt1t′=t2⋯tmt1 也是 ss 的一个子串。
Input
输入仅有一行,一个字符串 s(1≤lens≤300000)s(1≤lens≤300000) 。字符串 ss 仅包含小写字母。
Output
输出一个整数表示答案。
Sample Input
(样例输入1)
abaac
(样例输入2)
aaa
Sample Output
(样例输出1)
7
(样例输出2)
3
HINT
(样例解释)
第一组数据:符合条件的字符串 tt 有: a, b, c, aa, ab, ba, aba
第二组数据:符合条件的字符串 tt 有: a, aa, aaa
(数据范围与约定)
子任务1(10分): 1≤lens≤2001≤lens≤200
子任务2(30分): 1≤lens≤50001≤lens≤5000
子任务3(60分): 1≤lens≤300000
Solution
这题不算特别难, 但确实是一道好题.
一般而言, 后缀自动机的题目只需要用到后缀树上的连边, 但这题既用了后缀树的边, 又用了后缀自动机上的边.
题目的本质是要我们求出有多少组这样的((s, c)), 其中(s)为字符串, (c)为字符, 使得(sc)和(cs)都是原串的子串.
考虑枚举原串的每个子串(s), 再枚举每个字符(c), 则我们只需要判定(cs)和(sc)是否都是原串的子串即可.
考虑如何枚举原串的每个子串(s), 不难想到用后缀树; (cs)可以通过后缀树上记录每个节点中包含的每个字符数量以及是否有(c)这个儿子来统计; (sc)则只需要判断一个节点在后缀自动机上是否有(c)这个后继即可.
口胡了这么多, 总之, 就是用后缀树上的边来找前缀, 后缀自动机上的边来找后缀.
#include <cstdio>
#include <cstring>
typedef long long LL;
const int N = 5000, K = 47, MOD = (int)1e9 + 7;
int pw[N + 7], pwInv[N + 7];
int a[N + 7], f[N + 7][N + 7], sum[N + 7][N + 7], hsh[N + 7];
inline int getInverse(int a)
{
int res = 1;
for (int x = MOD - 2; x; x >>= 1, a = (LL)a * a % MOD) if (x & 1) res = (LL)res * a % MOD;
return res;
}
inline int getHash(int L, int R)
{
return (LL)(hsh[R] - hsh[L - 1] + MOD) * pwInv[L] % MOD;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("sequence.in", "r", stdin);
freopen("sequence.out", "w", stdout);
#endif
int n; scanf("%d
", &n);
pw[0] = 1; for (int i = 1; i <= n; ++ i) pw[i] = (LL)pw[i - 1] * K % MOD, pwInv[i] = getInverse(pw[i]);
hsh[0] = 0;
for (int i = 1; i <= n; ++ i) a[i] = getchar() - '0', hsh[i] = (hsh[i - 1] + (LL)a[i] * pw[i] % MOD) % MOD;
memset(f, 0, sizeof f); memset(sum, 0, sizeof sum);
f[0][0] = 1; for (int i = 0; i <= n; ++ i) sum[0][i] = 1;
for (int i = 1; i <= n; ++ i)
{
for (int j = 1; j <= i; ++ j)
{
if (a[i - j + 1] == 0) continue;
f[i][j] = sum[i - j][j - 1];
if (j <= i - j && getHash(i - j - j + 1, i - j) != getHash(i - j + 1, i))
{
int L = 1, R = j, p;
while (L <= R)
{
int mid = L + R >> 1;
if (getHash(i - j - j + 1, i - j - j + mid) != getHash(i - j + 1, i - j + mid)) p = mid, R = mid - 1;
else L = mid + 1;
}
if (a[i - j + p] > a[i - j - j + p]) f[i][j] = (f[i][j] + f[i - j][j]) % MOD;
}
}
sum[i][0] = 0;
for (int j = 1; j <= n; ++ j) sum[i][j] = (sum[i][j - 1] + f[i][j]) % MOD;
}
printf("%d
", sum[n][n]);
}