首先看第一题,一道DP+字典树的题目,具体中文题意和题解见训练指南209页。
初看这题模型还很难想,看过蓝书提示之后发现,这实际上是一个标准DP题目:通过数组来储存后缀节点的出现次数。也就是用一颗字典树从后往前搜一发。最开始觉得这种搞法怕不是要炸时间,当时算成了O(N*N)毕竟1e5的数据不搞直接上N*N的大暴力。。。后来发现,字典树根本跑不完N因为题目限制字典树最多右100层左右。
实际上这道题旧思想和模型来说很好(因为直观地想半天还真想不出来。。)但是实际实现起来很简单——撸一发字典树就好了。然而专门写一篇博客是因为自从学了刘汝佳的字典树之后就发现之前自己写的那个实在是太不优雅(使用了大量指针,还牵扯到内存回收的鬼故事),反而不如刘汝佳这种,一个类搞定一切,方便快捷,也不会因为莫名的bug调试一下午什么的。。于是来说说刘汝佳字典树的实现方式:
- 一个二维数组,cha【MAXN】【SIGMA_SIZE】用来存子节点的位置
- 一个标记数组,val【MAXN】用来储存每个节点的相关信息,比如是不是单词的结尾、第几次出现等
- 一个变量,size起到类似于栈顶指针的作用。
整体上,训练指南的字典树实现方案类似于一个大型栈,开开之后就一路往进压元素就好了。因而插入节点的时候很容易联想到入栈的过程。同时,整个字典树初始化时的常数也很小——不需要回收整棵字典树,只需要讲字典树的根节点指针置零、栈指针size置一就好;在每次增加元素的时候也只需要把当前元素的指针提前置零即可。
下面放AC代码:
#include<bits/stdc++.h> using namespace std; const long long MAXN=300233; char str[MAXN]; long long len=0; long long dp[MAXN]; const long long MOD=20071027; class AC_AUTO { public: long long cha[MAXN][26]; long long f[MAXN]; long long last[MAXN]; long long val[MAXN]; long long size; AC_AUTO() { init(); } void init() { memset(cha[0],0,sizeof(cha[0])); //避免大规模初始化浪费时间 size=1; // memset(val,0,sizeof(val)); } void insert(char *tar) { int len=strlen(tar); int u=0; for(int i=0;i<len;++i) { if(!cha[u][tar[i]-'a']) { memset(cha[size],0,sizeof(cha[size])); val[size]=0; cha[u][tar[i]-'a']=size; size++; } u=cha[u][tar[i]-'a']; }val[u]=1; } bool find(char *tar) { int l=strlen(tar); int u=0;int p1=len-l; for(int i=0;i<l;++i) { if(!cha[u][tar[i]-'a'])return false; u=cha[u][tar[i]-'a']; if(val[u]) { dp[p1]+=dp[p1+i+1]; dp[p1]%=MOD; } }return val[u]; } };AC_AUTO t1; long long kk=1; void init() { memset(dp,0,sizeof(dp)); t1.init(); len=strlen(str); long long n; cin>>n; for(int i=0;i<n;++i) { char sub[233]; cin>>sub; t1.insert(sub); } dp[len]=1; for(int i=len-1;i>=0;--i) { t1.find(str+i); } cout<<"Case "<<kk++<<": "<<dp[0]<<" "; } int main() { cin.sync_with_stdio(false); while(cin>>str)init(); return 0; }
事实上我写第一题主要是为了在第一题的基础上实现后面刘汝佳规约的AC自动机,于是上面代码的类名依然是AC_AUTO。刘汝佳规约的AC自动机首先是一颗字典树——加了失配边和后缀指针的字典树。
因而在上述字典树的基础上应当加入:
- f【MAXN】表示适配函数
- last【MAXN】表示失配函数中的最近一个单词节点(VAL【】不为零)
AC自动机在功能上应当是一个多重KMP,因而从原理上认为实现方式上应当等同于KMP——按照出现顺序向后遍历并在该过程中不断寻找失配边。于是考虑字典树情况,也应当按照层数逐渐递增的形式进行匹配,因而认为BFS很合适实现这个算法——(实现树的层次遍历),于是建立失配边的过程类似基本类似于KMP+BFS
本体有些坑在于数组尺寸的调教,如果没整好。。。就地TLE。。(不是数组越界是T。。)
另外训练指南中推荐使用map来保存字符串的出现顺序以避免重复情况,但是考虑到map直接使用【】来进行操作有比较大的常数,考虑到本身AC自动机就是一个字典树,于是强行在字典树中查询可能结果会更好。
然而。。。做了这个优化之后并没有发现实质的效率提升。。都是46毫秒。。。
#include<bits/stdc++.h> using namespace std; const long long MAXN=70*26+23003; const long long SIGMA_SIZE=30; char str[1000233]; char input[233][100]; long long cnt[233]; long long len=0,n=0; const long long MOD=20071027; map<string,int> ms; //char anss[1000233]; class AC_AUTO { public: long long cha[MAXN][SIGMA_SIZE]; long long f[MAXN]; long long last[MAXN]; long long val[MAXN]; long long size; AC_AUTO() { init(); } void init() { memset(cha[0],0,sizeof(cha[0])); //避免大规模初始化浪费时间 size=1; // memset(val,0,sizeof(val)); } void insert(char *tar,int numb) { int len=strlen(tar); int u=0; for(int i=0;i<len;++i) { if(!cha[u][tar[i]-'a']) { memset(cha[size],0,sizeof(cha[size])); val[size]=0; cha[u][tar[i]-'a']=size; size++; } u=cha[u][tar[i]-'a']; }val[u]=numb;//ms[string(tar)]=numb; } void print(int j) { if(j) { cnt[val[j]]++; print(last[j]); } } void find(char *tar) { int n=strlen(tar); int j=0; for(int i=0;i<n;++i) { int c=tar[i]-'a'; while(j&& !cha[j][c])j=f[j]; j=cha[j][c]; if(val[j])print(j); else if(last[j])print(last[j]); } } void getfail() { queue<int> q; f[0]=0; for(int c=0;c<SIGMA_SIZE;++c) { int u=cha[0][c]; if(u) { f[u]=0;q.push(u); last[u]=0; } } while(!q.empty()) { int r=q.front();q.pop(); for(int c=0;c<SIGMA_SIZE;++c) { int u=cha[r][c]; if(!u)continue; q.push(u); int v=f[r]; while(v&&!cha[v][c])v=f[v]; f[u]=cha[v][c]; last[u]= val[f[u]]? f[u]:last[f[u]]; } } } long long get(char *tar ) { int l=strlen(tar ); int u=0; for(int i=0;i<l;++i) { u=cha[u][tar[i]-'a']; } return val[u]; } };AC_AUTO a1; void init() { memset(cnt,0,sizeof(cnt)); // ms.clear(); a1.init(); for(int i=1;i<=n;++i) { scanf("%s",input[i]); a1.insert(input[i],i); } a1.getfail(); scanf("%s",str); a1.find(str); long long ans=-1; for(int i=0;i<=n;++i) { if(cnt[i]>ans)ans=cnt[i]; } printf("%lld ",ans); for(int i=1;i<=n;++i) { if(cnt[a1.get(input[i])]==ans)printf("%s ",input[i]); // else cout<<"not "<<input[i]<<ends<<cnt[ms[string(input[i])]]<<endl; } } int main() { // cin.sync_with_stdio(false); while(scanf("%lld",&n)==1&&n)init(); return 0; }