题意:给你n(n<=10000)个字符串,每个字符串的长度不超过30,可以选择两个非空前缀把它们拼起来得到一个字符串(这两个前缀可以来自同一个字符串,也可以是同一个字符串的同一个非空前缀),问得到的所有字符串中有多少个本质不同的字符串.
首先看到一堆字符串的前缀我们就可以想AC自动机,这个题意看上去只要在AC自动机上DP一下就好了,然后我看了眼题解发现确实是AC自动机上DP,然后就开始想怎么DP,想了2h才搞出来...交一发T了,减少无用状态的转移A了,15s,榜上倒数第二...(后来把指针+动态内存改为指针+静态内存,18s了...目测是大数组申请和构造函数的锅?)
其实我虽然搞出来了但是并不能讲明白
那么怎么DP呢?首先我们需要把一个可行的字符串对应为AC自动机上从根节点出发的一条路径,然后有两种情况,一种是这条路径没有沿着fail指针往回走过,那么这条路径本身也对应着某个串的一个前缀,我们在AC自动机上遍历一遍(实现时可以在求fail指针时顺便做)就可以统计所有的这种路径,判断这条路径能否拆分成两个前缀就只需看路径的终点的fail指针是否指向根节点即可(如果不指向根节点,那么fail指针指向的那个节点对应了拆分方案中后面的那个前缀,而去掉后面那个前缀之后,这条路径前面的部分必然还是一个前缀).
第二种情况就是这条路径沿着fail指针往回走过. 这种情况比上一种情况复杂.如何判断一条路径是否合法?我们可以在这条路径的开头找出一个尽量长的前缀,然后在这条路径的结尾找出一个尽量长的前缀,判断这样的两个前缀能否组成整条路径.那么开头位置的尽量长的前缀对应着从AC自动机的根节点走到这条路径第一次跳fail指针的位置.假如在第一次跳fail指针后走了x步到达了终点,那么终点的深度对应着这条路径结尾位置的最长前缀.只要终点的深度大于等于x,我们就能找到合法的拆分方案把这条路径拆成两个前缀.
那么定义状态时首先可以想到f[i][j][k]表示从第i个节点第一次跳fail指针,走了j步到达一个深度为k的节点的方案有多少,这样好像会MLE+TLE.注意到从哪个位置开始第一次跳fail指针并没有什么用,我们关注的是终点的深度,那么定义f[i][j]表示第一次跳fail指针之后走了j步到达节点i的方案有多少,我写的复杂度是O(节点总数*最大深度*字符集大小),也就是300000×30×26....减少一些无用状态的转移之后能15s跑过去也是感人....榜上大神们都跑得好快呀不知道是有复杂度更好的方法还是我的常数太丑了?不过加了滚动数组之后10s了233333
#include<cstdio> #include<cstring> #include<queue> #include<algorithm> using namespace std; const int maxn=300005; struct node{ node* ch[26],*fail; int num,depth; node(){} node(int x,int d){depth=d;memset(ch,0,sizeof(ch));fail=0;num=x;} }*root;int tot=0; char str[35]; node* pos[maxn]; void Add(char *c){ node* p=root; while(*c){ int t=*c-'a'; if(p->ch[t]==NULL){p->ch[t]=new node(++tot,p->depth+1);pos[tot]=p->ch[t];} p=p->ch[t];++c; } } long long f[2][maxn]; long long ans=0; void getfail(){ queue<node*> q;q.push(root); while(!q.empty()){ node* x=q.front();q.pop(); if(x!=root&&x->fail!=root)ans++; for(int i=0;i<26;++i){ if(x->ch[i]){ if(x==root)x->ch[i]->fail=root; else x->ch[i]->fail=x->fail->ch[i]; q.push(x->ch[i]); }else{ if(x==root)x->ch[i]=root; else x->ch[i]=x->fail->ch[i]; f[1][x->ch[i]->num]++; } } } } int main(){ int n;scanf("%d",&n);int maxlen=0; root=new node(0,0);pos[0]=root; for(int i=1;i<=n;++i){ scanf("%s",str);int len=strlen(str);if(len>maxlen)maxlen=len; Add(str); } getfail(); int flag=1; for(int k=1;k<=maxlen;++k){ for(int i=1;i<=tot;++i)f[flag^1][i]=0; for(int i=1;i<=tot;++i){ if(f[flag][i]==0||pos[i]->depth<k)continue; for(int j=0;j<26;++j){ f[flag^1][pos[i]->ch[j]->num]+=f[flag][i]; } if(k<=pos[i]->depth)ans+=f[flag][i]; //printf("f[%d][%d]==%lld ",i,k,f[i][k]); } flag^=1; } printf("%lld ",ans); return 0; }
#include<cstdio> #include<cstring> #include<queue> #include<algorithm> using namespace std; const int maxn=300005; struct node{ node* ch[26],*fail; int num,depth; node(int x,int d){depth=d;memset(ch,0,sizeof(ch));fail=0;num=x;} }*root;int tot=0; char str[35]; node* pos[maxn]; void Add(char *c){ node* p=root; while(*c){ int t=*c-'a'; if(p->ch[t]==NULL){p->ch[t]=new node(++tot,p->depth+1);pos[tot]=p->ch[t];} p=p->ch[t];++c; } } long long f[maxn][32]; long long ans=0; void getfail(){ queue<node*> q;q.push(root); while(!q.empty()){ node* x=q.front();q.pop(); if(x!=root&&x->fail!=root)ans++; for(int i=0;i<26;++i){ if(x->ch[i]){ if(x==root)x->ch[i]->fail=root; else x->ch[i]->fail=x->fail->ch[i]; q.push(x->ch[i]); }else{ if(x==root)x->ch[i]=root; else x->ch[i]=x->fail->ch[i]; f[x->ch[i]->num][1]++; } } } } int main(){ int n;scanf("%d",&n);int maxlen=0; root=new node(0,0);pos[0]=root; for(int i=1;i<=n;++i){ scanf("%s",str);int len=strlen(str);if(len>maxlen)maxlen=len; Add(str); } getfail(); // for(int i=1;i<=tot;++i){ // for(int j=0;j<26;++j){ // if(pos[i]->ch[j]->num<=i){ // f[pos[i]->ch[j]->num][1]++; // } // } // } //printf("ans==%lld ",ans); for(int k=1;k<=maxlen;++k){ for(int i=1;i<=tot;++i){ if(f[i][k]==0||pos[i]->depth<k)continue; for(int j=0;j<26;++j){ f[pos[i]->ch[j]->num][k+1]+=f[i][k]; } if(k<=pos[i]->depth)ans+=f[i][k]; //printf("f[%d][%d]==%lld ",i,k,f[i][k]); } } printf("%lld ",ans); return 0; }