• hdu4758 hdu2825 hdu4057 AC自动机与状态压缩dp的结合


    最近做到好几道关于AC自动机与状态压缩dp的结合的题,这里总结一下。

    题目一般会给出m个字符串,m不超过10,然后求长度为len并且包含特定给出的字符串集合的字符串个数。

    以HDU 4758为例:

    把题意抽象为:给出两个字符串,且只包含两种字符 'R'、'D',现在求满足下列条件的字符串个数:

    1、字符串必须包含上述两个字符串。

    2、字符串长度为(m+n),其中包含n个'D',m个'R'。

    如果不用AC自动机来做,这道题还真没法做了,因为不管怎样都找不到正确的dp状态转移方程。

    而如果引入AC自动机,把在AC自动机上的结点当做dp的一个维度的状态,那么问题就可解了。

    dp[c][zt][i][j]:c表示当前状态的字符串对应于AC自动机上的结点,zt表示给定字符串取舍情况的压缩状态,i表示'D'的个数,j表示'R'的个数。

    那么dp[c][zt][i][j]表示当前状态字符串的个数。

    循环到dp[c][zt][i][j]时,其实dp[c][zt][i][j]已经被计算出来了,然后遍历trie树中c的所有子节点,计算它们的dp值。

    最外层循环应该是字符串长度的循环,循环次数是题目要求的字符串长度,第二层循环是trie树中的所有节点,第三层是字符串取舍状态,最后是遍历c节点的所有子节点(说是子节点,其实是对c节点的下一个字符进行遍历,需要使用fail指针)。

    c节点并不代表某个具体的字符串,它是指所有能到达c节点的字符串,dp的值就是保存这些字符串中满足条件的字符串个数。

    AC自动机的作用就是增加一个状态维度,使dp过程有足够的信息来转移状态。

    #include<cstdio>
    #include<cstring>
    #include<queue>
    using namespace std;
    const int mod = 1000000007;
    int ch[202][2],End[202],cur,fail[202],last[202];
    void get_fail() {
        int now,tmpFail,Next;
        queue<int> q;
        for(int j=0;j<2;j++) {
            if(ch[0][j]) {
                q.push(ch[0][j]);
                fail[ch[0][j]] = 0;
                last[ch[0][j]] = 0;
            }
        }
        while(!q.empty()) {
            now = q.front();q.pop();
            for(int j=0;j<2;j++) {
                if(!ch[now][j]) continue;
                Next = ch[now][j];
                q.push(Next);
                tmpFail = fail[now];
                while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail];
                fail[Next] = ch[tmpFail][j];
                last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]];
            }
        }
    }
    int dp[202][4][102][102];//dp[c][zt][i][j]
    int main() {
        int T,m,n;
        char str0[3][104];
        scanf("%d",&T);
        while(T--) {
            cur=1;
            scanf("%d%d",&m,&n);
            n++;m++;
            memset(End,0,sizeof(End));
            memset(ch,0,sizeof(ch));
            memset(last,0,sizeof(last));
            for(int i=1;i<=2;i++) {
                scanf("%s",str0[i]);
                int len = strlen(str0[i]);
                int now = 0;
                for(int j=0;j<len;j++) {
                    if(str0[i][j]=='R') str0[i][j]=1;
                    else str0[i][j]=0;
                    if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++;
                    now = ch[now][str0[i][j]];
                }
                End[now] = i;
            }
            get_fail();
    
    
            memset(dp,0,sizeof(dp));
            dp[0][0][0][0]=1;
            for(int i=0;i<n;i++) //要特别注意这里内外循环顺序,必须把i、j循环放在外面
            for(int j=0;j<m;j++) {
                for(int c=0;c<cur;c++) {
                    for(int zt=0;zt<=3;zt++){
                        if(dp[c][zt][i][j])
                        for(int k=0;k<2;k++) {
                            if(k==0&&i==n-1) continue;
                            else if(k==1&&j==m-1) continue;
                            int now=c;
                            while(now&&!ch[now][k]) now = fail[now];
                            now = ch[now][k];
    
                            int t=0;
                            if(End[now])
                                t = t|(1<<(End[now]-1));
                            int tmp = now;
                            while(last[tmp]) {
                                t = t|(1<<(End[last[tmp]]-1));
                                tmp = last[tmp];
                            }
                            if(k==0) {
                                dp[now][zt|t][i+1][j] += dp[c][zt][i][j];
                                if(dp[now][zt|t][i+1][j]>=mod) dp[now][zt|t][i+1][j]-=mod;
                            }
                            else if(k==1) {
                                dp[now][zt|t][i][j+1] += dp[c][zt][i][j];
                                if(dp[now][zt|t][i][j+1]>=mod) dp[now][zt|t][i][j+1]-=mod;
                            }
                        }
                    }
                }
            }
            long long ans=0;
            for(int i=0;i<cur;i++) {
                ans+=dp[i][3][n-1][m-1];
                if(ans>=mod) ans-=mod;
            }
            printf("%I64d
    ",ans);
        }
    }

    注意循环的内外顺序,一般情况下,字符串长度的循环都是放在外层,也就是说,一定要先计算出长度为i的所有字符串状态,才能计算长度为i+1的所有字符串状态。

    类似的 HDU 2825 :给 m 个单词构成的集合,求至少包含 k 个单词且长度为n的字符串个数。

    #include<iostream>
    #include<algorithm>
    #include<cstring>
    #include<cstdio>
    #include<queue>
    using namespace std;
    const int mod=20090717;
    int ch[11*11][26],End[11*11],cur,fail[11*11],last[11*11];
    char str0[12][12];
    void get_fail() {
        int now,tmpFail,Next;
        queue<int> q;
        for(int j=0;j<26;j++) {
            if(ch[0][j]) {
                q.push(ch[0][j]);
                fail[ch[0][j]] = 0;
                last[ch[0][j]] = 0;
            }
        }
        while(!q.empty()) {
            now = q.front();q.pop();
            for(int j=0;j<26;j++) {
                if(!ch[now][j]) continue;
                Next = ch[now][j];
                q.push(Next);
                tmpFail = fail[now];
                while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail];
                fail[Next] = ch[tmpFail][j];
                last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]];
            }
        }
    }
    int dp[27][11*11][1055];
    int main()
    {
        int sum[1055];
        for(int I=0;I<(1<<10);I++) {
                sum[I]=0;
                int tmp=I;
                while(tmp) {
                    if(tmp&1) sum[I]++;
                    tmp>>=1;
                }
        }
        int n,m,k;
        while(scanf("%d%d%d",&n,&m,&k)!=EOF&&(n||m||k))
        {
            cur=1;
            int len[13];
            memset(End,0,sizeof(End));
            memset(ch,0,sizeof(ch));
            memset(last,0,sizeof(last));
            for(int i=1;i<=m;i++) {
                scanf("%s",str0[i]);
                len[i] = strlen(str0[i]);
                int now = 0;
                for(int j=0;j<len[i];j++) {
                    str0[i][j]-='a';
                    if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++;
                    now = ch[now][str0[i][j]];
                    str0[i][j]+='a';
                }
                End[now] = i;
    
            }
            get_fail();
            memset(dp,0,sizeof(dp));
            dp[0][0][0]=1;
            int pre=0,zt=0;
            int ans=0;
            for(int i=0;i<n;i++) {
                for(int j=0;j<cur;j++) {
                    for(int zt=0;zt<(1<<m);zt++) {
                        if(dp[i][j][zt]) {
                        for(int c=0;c<26;c++) {
                            int now = j;
                            while(now&&!ch[now][c]) now = fail[now];
                            now = ch[now][c];
                            int t=0;
                            if(End[now])
                                t = t|(1<<(End[now]-1));
                            int tmp = now;
                            while(last[tmp]) {
                                t = t|(1<<(End[last[tmp]]-1));
                                tmp = last[tmp];
                            }
                            dp[i+1][now][zt|t] += dp[i][j][zt];
                            if(dp[i+1][now][zt|t]>=mod) dp[i+1][now][zt|t]-=mod;
                        }
                        }
                    }
                }
            }
            for(int I=0;I<(1<<m);I++) {
                if(sum[I]>=k) {
                    for(int j=0;j<cur;j++){
                        ans+=dp[n][j][I];
                        if(ans>=mod) ans-=mod;
    
                    }
                }
            }
            printf("%d
    ",ans);
        }
    }

    HDU 4057:给出一些模式串,每个串有一定的价值,现在构造一个长度为M的串,问最大的价值为多少,每个模式串最多统计一次。

    #include<cstdio>
    #include<cstring>
    #include<queue>
    using namespace std;
    int ch[11*102][4],End[11*102],cur,fail[11*102],last[11*102];
    int w[11];
    char str[102],str0[11][102];
    void get_fail()
    {
        int now,tmpFail,Next;
        queue<int> q;
        //用bfs生成fail
        //初始化队列
        for(int j=0; j<4; j++)
        {
            if(ch[0][j])
            {
                q.push(ch[0][j]);
                fail[ch[0][j]] = 0;
                last[ch[0][j]] = 0;
            }
        }
        while(!q.empty())
        {
            //从队列中拿出now
            //此时now中的fail、last已经算好了
            //下面计算的是ch[now][j]中的fail、last。
            now = q.front();
            q.pop();
            for(int j=0; j<4; j++)
            {
                if(!ch[now][j]) continue;
                Next = ch[now][j];
                q.push(Next);
                tmpFail = fail[now];
                while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail];
                fail[Next] = ch[tmpFail][j];
                last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]];
            }
        }
    }
    int dp[1029][11*102][2];
    bool vis[1029][11*102][2];
    int n,l,now,ans;
    queue<int> quezt;
    queue<int> quenow;
    queue<int> quelen;
    void bfs (int zt,int now0,int len)
    {
        //printf("%d %d %d %d
    ",zt,now0,len,dp[zt][now0][len%2]);
        //printf("%d
    ",quezt.size());
        if(len==l) ans=max(ans,dp[zt][now0][l%2]);
        if(len==l+1) return;
        for(int i=0; i<4; i++)
        {
            int now=now0,temp=0;
            while(now&&!ch[now][i]) now = fail[now];
            now = ch[now][i];
            int newzt = zt;
            if(End[now])
            {
                if(((1<<(End[now]-1))|newzt)!=newzt) temp+=w[End[now]];
                newzt = (1<<(End[now]-1))|newzt;
            }
            int tmp = now;
            while(last[tmp])
            {
                if(End[last[tmp]])
                {
                    if(((1<<(End[last[tmp]]-1))|newzt)!=newzt) temp+=w[End[last[tmp]]];
                    newzt = (1<<(End[last[tmp]]-1))|newzt;
                }
                tmp = last[tmp];
            }
            if(newzt!=zt) {
                //printf("%d
    ",temp);
                if(!vis[newzt][now][(len+1)%2]) dp[newzt][now][(len+1)%2]=dp[zt][now0][len%2]+temp;
                else dp[newzt][now][(len+1)%2]=max(dp[zt][now0][len%2]+temp,dp[newzt][now][(len+1)%2]);
            }
            else{
                if(!vis[zt][now][(len+1)%2]) dp[zt][now][(len+1)%2]=dp[zt][now0][len%2];
                else dp[zt][now][(len+1)%2]=max(dp[zt][now0][len%2],dp[zt][now][(len+1)%2]);
            }
            //dfs(newzt,now,len+1);
            if(!vis[newzt][now][(len+1)%2]) {
                quezt.push(newzt);
                quenow.push(now);
                quelen.push(len+1);
                vis[newzt][now][(len+1)%2]=true;
            }
        }
        //if(len==l) ans=max(ans,dp[zt][now0][l%2]);
    }
    int main()
    {
        while(scanf("%d%d",&n,&l)!=EOF)
        {
            memset(dp,-1,sizeof(dp));
            memset(ch,0,sizeof(ch));
            memset(End,0,sizeof(End));
            memset(last,0,sizeof(last));
            cur = 1;
            int len;
            for(int i=1; i<=n; i++)
            {
                scanf("%s%d",str0[i],&w[i]);
                //puts(str0[i]);
                len = strlen(str0[i]);
                now = 0;
                for(int j=0; j<len; j++)
                {
                    if(str0[i][j]=='A') str0[i][j]=0;
                    if(str0[i][j]=='T') str0[i][j]=1;
                    if(str0[i][j]=='G') str0[i][j]=2;
                    if(str0[i][j]=='C') str0[i][j]=3;
                    if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++;
                    now = ch[now][str0[i][j]];
                    if(str0[i][j]==0) str0[i][j]='A';
                    if(str0[i][j]==1) str0[i][j]='T';
                    if(str0[i][j]==2) str0[i][j]='G';
                    if(str0[i][j]==3) str0[i][j]='C';
                }
                End[now] = i;
            }
            //printf("%d
    ",cur);
            get_fail();
            //printf("%d
    ",cur);
            dp[0][0][0]=0;
            quezt.push(0);
            quenow.push(0);
            quelen.push(0);
            memset(vis,false,sizeof(vis));
            vis[0][0][0]=true;
            ans=-1;
            int pre=0;
            while(!quezt.empty()) {
                //if(quelen.front()!=pre) {
                //    for(int i=0;i<1029;i++)
                //    for(int j=0;j<11*102;j++) dp[i][j][pre%2]=0;
                //    pre=quelen.front();
                //}
                bfs(quezt.front(),quenow.front(),quelen.front());
                vis[quezt.front()][quenow.front()][quelen.front()%2]=false;
                quezt.pop();quenow.pop();quelen.pop();
            }
            if(ans==-1) puts("No Rabbit after 2012!");
            else printf("%d
    ",ans);
        }
    }
  • 相关阅读:
    如何将一个用utf-8编码的文本用java程序转换成ANSI编码的文本
    【笔记】Nginx热更新相关知识
    网站性能测试工具 webbench 的安装和使用
    Windows 7环境下网站性能测试小工具 Apache Bench 和 Webbench使用和下载
    【笔记】Rancher2.1容器云平台新特性
    MinTTY终端模拟器要点
    CEBX格式的文档如何转换为PDF格式文档、DOCX文档?
    Rancher2.0与DataDog集成部署
    使用Docker方式创建3节点的Etcd集群
    NTP服务器时间同步
  • 原文地址:https://www.cnblogs.com/lastone/p/5293207.html
Copyright © 2020-2023  润新知