Description
定义 (f(s,t)) 为最大的 (i) 满足 (s) 的长为 (i) 的前缀和 (t) 的长为 (i) 的后缀相等。给定 (n) 个字符串 (s_1,s_2,...,s_n),求 (sum_{i=1}^n sum_{j=1}^n f(s_i,s_j)^2)。
Solution
考虑哈希,将每个后缀存入桶中。先不考虑 (f(s,t)) 的最大,而考虑所有满足条件的和,那么暴力扫描每个前缀,设桶中有 (k) 个后缀与它匹配,那么对答案的贡献为 (kl^2)。
但由于我们要算的是最大的,这样其中会有重复,观察发现这种情况会发生,是因为我们枚举的前缀有一个非空的 Border,即对这个前缀 (s[1..i]),有 (s[1..j]=s[i-j+1..i], j<i),因此我们需要减去 (s[1..j]) 的贡献(减去 (kj^2),这里的 (k) 是本次的匹配数目),因为这些已经不是最长的了。
因此一个 Hash 加一个 Kmp 即可解决问题。
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ull unsigned long long
const int N = 1000005;
const ull bas = 131;
const int mod = 998244353;
ull basPower[N];
// 定义为 fail[i] 表示 Border(s[0..i])
vector<int> kmp(string s)
{
int n=s.length();
s+=' ';
vector<int> fail(n+1);
for(int i=1;i<=n;i++)
{
fail[i]=fail[i-1];
while(s[fail[i]]!=s[i] && fail[i]) fail[i]=fail[fail[i]-1];
if(s[fail[i]]==s[i]) ++fail[i];
}
return fail;
}
struct HashString
{
string str;
vector<ull> hash;
void presolve(string srcString)
{
str=srcString;
hash.clear();
int n=str.length();
hash.resize(n);
hash[0]=str[0];
for(int i=1;i<n;i++) hash[i]=str[i]+hash[i-1]*bas;
}
HashString()
{
}
HashString(string srcString)
{
presolve(srcString);
}
ull getHash(int l,int r)
{
return hash[r]-(l?hash[l-1]:0)*basPower[r-l+1];
}
};
signed main()
{
ios::sync_with_stdio(false);
basPower[0]=1;
for(int i=1;i<N;i++)
{
basPower[i]=basPower[i-1]*bas;
}
/*string str;
cin>>str;
vector <int> fail = kmp(str);
HashString hashString;
hashString.presolve(str);*/
int n;
cin>>n;
vector <string> strSet(n);
vector <HashString> hashstrSet(n);
map <ull,int> mp; // 每种后缀的出现次数
for(int i=0;i<n;i++)
{
cin>>strSet[i];
hashstrSet[i].presolve(strSet[i]);
int len=strSet[i].length();
for(int j=0;j<len;j++)
{
mp[hashstrSet[i].getHash(j,len-1)]++;
}
}
int ans=0;
for(int i=0;i<n;i++)
{
string &str=strSet[i];
HashString &hashstr=hashstrSet[i];
int len=str.length();
vector <int> nextArray=kmp(str);
for(int j=0;j<len;j++)
{
ans+=mp[hashstr.getHash(0,j)]*(j+1)%mod*(j+1);
if(nextArray[j]>0)
{
ans-=mp[hashstr.getHash(0,j)]*nextArray[j]%mod*nextArray[j];
}
ans%=mod;
ans+=mod;
ans%=mod;
}
}
cout<<ans<<endl;
return 0;
}