【BZOJ4861】[Beijing2017]魔法咒语
题意:别看BZ的题面了,去看LOJ的题面吧~
题解:显然,数据范围明显的分成了两部分:一个是L很小,每个基本词汇长度未知;一个是L很大,每个基本词汇的长度是1或2。看来只能写两份代码了。
对于L很小的,我们先将禁忌串建成一个AC自动机,然后预处理出to[i][j]表示AC自动机中的第i个节点在加入基本词汇j后会到达的节点。然后设f[i][j]表示总长度为i,匹配到第j个节点的方案数。然后DP一下就好了。
对于L很大的,我们想到矩乘,设ans[i][j]表示总长度为i,匹配到第j个节点的方案数。但是ans[i]这个矩阵由ans[i-1]和ans[i-2]两个矩阵转移过来,所以我们直接用分块矩阵的乘法,即:
#include <cstdio> #include <cstring> #include <iostream> #include <queue> using namespace std; typedef long long ll; const ll mod=1000000007; int n,m,N,M,L,mx,sum; struct mat { ll v[210][210]; mat (){memset(v,0,sizeof(v));} ll* operator [](int a){return v[a];} mat operator * (mat a) { mat ret; int i,j,k; for(i=1;i<=2*M;i++) for(j=1;j<=2*M;j++) for(k=1;k<=2*M;k++) (ret[i][j]+=v[i][k]*a[k][j])%=mod; return ret; } }ans,x; int l1[60],to[110][60]; ll f[110][110]; queue<int> q; struct node { int ch[26],fail,cnt; }p[110]; char s1[60][110],s2[60][110]; void build() { q.push(1); int i,j,k,a,u; while(!q.empty()) { u=q.front(),q.pop(); for(i=0;i<26;i++) { if(!p[u].ch[i]) { if(u==1) p[u].ch[i]=1; else p[u].ch[i]=p[p[u].fail].ch[i]; continue; } q.push(p[u].ch[i]); if(u==1) { p[p[u].ch[i]].fail=1; continue; } p[p[u].ch[i]].fail=p[p[u].fail].ch[i]; p[p[u].ch[i]].cnt|=p[p[p[u].fail].ch[i]].cnt; } } for(i=1;i<=M;i++) for(j=1;j<=n;j++) { u=i,a=strlen(s1[j]); if(p[u].cnt) to[i][j]=-1; for(k=0;k<a;k++) { u=p[u].ch[s1[j][k]-'a']; if(p[u].cnt) break; } if(k==a) to[i][j]=u; else to[i][j]=-1; } } void DP() { int i,j,k,a; f[0][1]=1; for(i=0;i<L;i++) for(j=1;j<=M;j++) for(k=1;k<=n;k++) { if(to[j][k]==-1) continue; a=strlen(s1[k]); if(a+i<=L) (f[a+i][to[j][k]]+=f[i][j])%=mod; } for(i=1;i<=M;i++) sum=(sum+f[L][i])%mod; printf("%d",sum); } void pm(int y) { while(y) { if(y&1) ans=ans*x; x=x*x,y>>=1; } } void MM() { int i,j; for(i=1;i<=M;i++) { for(j=1;j<=n;j++) { if(to[i][j]==-1) continue; if(strlen(s1[j])==1) x[i][to[i][j]]++; else x[i+M][to[i][j]]++; } x[i][i+M]++; } ans[1][1]=1; pm(L); for(i=1;i<=M;i++) sum=(sum+ans[1][i])%mod; printf("%d",sum); } int main() { scanf("%d%d%d",&n,&m,&L); int i,j,a,b,u; N=1,M=1; for(i=1;i<=n;i++) scanf("%s",s1[i]),a=strlen(s1[i]),mx=max(mx,a); for(i=1;i<=m;i++) { scanf("%s",s2[i]),a=strlen(s2[i]); for(u=1,j=0;j<a;j++) { b=s2[i][j]-'a'; if(!p[u].ch[b]) p[u].ch[b]=++M; u=p[u].ch[b]; } p[u].cnt=1; } build(); if(mx<=2) MM(); else DP(); return 0; }