*题目描述:
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。
*输入:
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
*输出:
输出一个整数表示答案
*样例输入:
aabb
bbaa
*样例输出:
10
*题解:
构造广义后缀自动机,分别统计每个节点在两个串中的出现次数,然后答案就是每个节点的中间节点的个数乘上节点在两个串中出现次数之积。
*代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#ifdef WIN32
#define LL "%I64d"
#else
#define LL "%lld"
#endif
#ifdef CT
#define debug(...) printf(__VA_ARGS__)
#define setfile()
#else
#define debug(...)
#define filename ""
#define setfile() freopen(filename".in", "r", stdin); freopen(filename".out", "w", stdout);
#endif
#define R register
#define getc() (S == T && (T = (S = B) + fread(B, 1, 1 << 15, stdin), S == T) ? EOF : *S++)
#define dmax(_a, _b) ((_a) > (_b) ? (_a) : (_b))
#define dmin(_a, _b) ((_a) < (_b) ? (_a) : (_b))
#define cmax(_a, _b) (_a < (_b) ? _a = (_b) : 0)
#define cmin(_a, _b) (_a > (_b) ? _a = (_b) : 0)
char B[1 << 15], *S = B, *T = B;
inline int FastIn()
{
R char ch; R int cnt = 0; R bool minus = 0;
while (ch = getc(), (ch < '0' || ch > '9') && ch != '-') ;
ch == '-' ? minus = 1 : cnt = ch - '0';
while (ch = getc(), ch >= '0' && ch <= '9') cnt = cnt * 10 + ch - '0';
return minus ? -cnt : cnt;
}
#define maxn 400010
struct sam
{
sam *next[27], *fa;
int val, size1, size2, deg;
}mem[maxn << 1], *tot = mem;
inline sam *extend(R sam *p, R int c)
{
if (p -> next[c])
{
R sam *q = p -> next[c];
if (q -> val == p -> val + 1)
return q;
else
{
R sam *nq = ++tot;
nq -> val = p -> val + 1;
memcpy(nq -> next, q -> next, sizeof nq -> next);
nq -> fa = q -> fa;
q -> fa = nq;
for ( ; p && p -> next[c] == q; p = p -> fa)
p -> next[c] = nq;
return nq;
}
}
R sam *np = ++tot;
np -> val = p -> val + 1;
for ( ; p && !p -> next[c]; p = p -> fa)
p -> next[c] = np;
if (!p)
np -> fa = mem;
else
{
R sam *q = p -> next[c];
if (q -> val == p -> val + 1)
np -> fa = q;
else
{
R sam *nq = ++tot;
nq -> val = p -> val + 1;
memcpy(nq -> next, q -> next, sizeof nq -> next);
nq -> fa = q -> fa;
q -> fa = np -> fa = nq;
for ( ; p && p -> next[c] == q; p = p -> fa)
p -> next[c] = nq;
}
}
return np;
}
char str[maxn];
sam *q[maxn << 1];
int main()
{
// setfile();
gets(str);
R int len = strlen(str);
R sam *now = mem;
for (R int i = 0; i < len; ++i)
{
now = extend(now, str[i] - 'a');
++now -> size1;
}
gets(str);
len = strlen(str); now = mem;
for (R int i = 0; i < len; ++i)
{
now = extend(now, str[i] - 'a');
++now -> size2;
}
for (R sam *iter = mem; iter <= tot; ++iter)
if (iter -> fa)
++iter -> fa -> deg;
R int head = 0, tail = 0;
for (R sam *iter = mem + 1; iter <= tot; ++iter)
if (!iter -> deg)
q[++tail] = iter;
while (head < tail)
{
++head;
now = q[head];
if (!now -> fa) continue;
now -> fa -> size1 += now -> size1;
now -> fa -> size2 += now -> size2;
--now -> fa -> deg;
if (!now -> fa -> deg) q[++tail] = now -> fa;
}
R long long ans = 0;
for (R sam *iter = mem; iter <= tot; ++iter)
if (iter -> fa)
ans += 1ll * (iter -> val - iter -> fa -> val) * iter -> size1 * iter -> size2;
printf("%lld
", ans );
return 0;
}
/*
aabb
bbaa
*/