[LuoguP3502] [BZOJ 2085] [POJ2010]CHO-Hamsters(KMP+最短路+矩阵快速幂)
题面
Tz养了(n)只仓鼠,他们都有英文小写的名字,现在Tz想用一个字母序列来表示他们的名字,只要他们的名字是字母序列中的一个子串就算,出现多次可以重复计算。现在Tz想好了要出现(m)个名字,请你求出最短的字母序列的长度是多少。
(n leq 200,m leq 10^9)
分析
考虑建图:
先用KMP求出任意两个串(i,j)之间的最长公共前后缀长度(w),连边((i,j,w)). 注意((i,i))的边权为第(i)个字符串的长度。那么答案就是min(从(i)出发最短的经过边为(m - 1)的路径+第(i)个字符串的长度)。(可以重复经过点和边)。注意到邻接矩阵的(k)次方就表示走(k)条边之后两点间的状态。那么我们类似floyd最短路重新定义矩阵乘法:
for(int k=1;k<=n;k++){
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
ans.dist[i][j]=min(ans.dist[i][j],p.dist[i][k]+q.dist[k][j]);
}
}
}
这样,k次方后的矩阵的第i行j列就表示开始为i,终点为(j)的最短路径。遍历一遍就可以求出答案。
注意常数优化
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#define maxn 200
#define maxl 100000
#define INF 0x3f3f3f3f3f3f3f3f
using namespace std;
typedef long long ll;
int n,m;
char a[maxn+5][maxl+5];
int len[maxn+5];
void get_nex(char *s,int n,int *nex){
nex[1]=0;
for(int i=2,j=0;i<=n;i++){
while(j&&s[j+1]!=s[i]) j=nex[j];
if(s[j+1]==s[i]) j++;
nex[i]=j;
}
}
int kmp(char *s,int n,char *t,int m){
static int nex[maxl+5];
static int f[maxl+5];
get_nex(t,m,nex);
for(int i=1,j=0;i<=n;i++){
while(j&&t[j+1]!=s[i]) j=nex[j];
if(t[j+1]==s[i]) j++;
f[i]=j;
}
if(f[n]==m) return m-nex[m];//两个相同串不能全部覆盖
else return m-f[n];
}
struct matrix{
ll dist[maxn+5][maxn+5];//邻接矩阵的k次方就表示走k-1步
//把乘法换成+,加法换成min,就能求出经过k个点的最短路
void ini(){
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++) dist[i][j]=INF;
}
}
friend matrix operator * (matrix &p,matrix &q){
matrix ans;
ans.ini();
for(int k=1;k<=n;k++){
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
ans.dist[i][j]=min(ans.dist[i][j],p.dist[i][k]+q.dist[k][j]);
}
}
}
return ans;
}
};
matrix fast_pow(matrix &x,ll k){
matrix ans=x;
k--;
while(k>0){
if(k&1) ans=ans*x;
x=x*x;
k>>=1;
}
return ans;
}
matrix d;
int main(){
// freopen("data.txt","r",stdin);
scanf("%d %d",&n,&m);
ll minl=INF;
for(int i=1;i<=n;i++){
scanf("%s",a[i]+1);
len[i]=strlen(a[i]+1);
minl=min(minl,(ll)len[i]);
}
d.ini();
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
d.dist[i][j]=kmp(a[i],len[i],a[j],len[j]);
// printf("%lld ",d.dist[i][j]);
}
// printf("
");
}
if(m==1){
printf("%lld
",minl);
return 0;
}
d=fast_pow(d,m-1);
ll ans=INF;
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
ans=min(ans,len[i]+d.dist[i][j]);
}
}
printf("%lld
",ans);
}