• ACM之路(16)—— 数位DP


      题目就是kuangbin的数位DP。

      先讲C题,不要62,差不多就是一个模板题。要注意的是按位来的话,光一个pos是不够的,还需要一维来记录当前位置是什么数字,这样才能防止同一个pos不同数字的dp值混在一起。直接丢代码:

     1 #include <stdio.h>
     2 #include <algorithm>
     3 #include <string.h>
     4 #include <iostream>
     5 #include <vector>
     6 #include <queue>
     7 #include <math.h>
     8 using namespace std;
     9 typedef long long ll;
    10 
    11 // pos , 记录当前位是什么
    12 // 第二维是为了防止同一个pos的不同which的dp混淆在一起,因为是记忆化搜索的
    13 int bit[12],dp[12][12];
    14 
    15 int dfs(int pos,int which,bool have_six,bool flag)
    16 {
    17     if(pos == -1) return 1;
    18     int& ans = dp[pos][which];
    19     if(flag && ans!=-1) return ans;
    20     int d = flag?9:bit[pos];
    21 
    22     int ret = 0;
    23     for(int i=0;i<=d;i++)
    24     {
    25         if(i==4) continue;
    26         if(have_six && i==2) continue;
    27         ret += dfs(pos-1,i,i==6,flag||i<d);
    28     }
    29     if(flag) ans = ret;
    30     return ret;
    31 }
    32 
    33 int solve(int x)
    34 {
    35     //if(x==0) return 0;
    36     int pos = 0;
    37     while(x)
    38     {
    39         bit[pos++] = x % 10;
    40         x /= 10;
    41     }
    42 
    43     int ans = 0;
    44     ans += dfs(pos-1,0,false,false);
    45     return ans; 
    46     // 因为0也是一个值
    47     // 所以solve(5)=5是因为0.1.2.3.5
    48 }
    49 
    50 int main()
    51 {
    52     int x,y;
    53     memset(dp,-1,sizeof(dp));
    54     //printf("%d !!
    ",solve(5));
    55     while(scanf("%d%d",&x,&y)==2)
    56     {
    57         if(x==0 && y==0) break;
    58         else printf("%d
    ",solve(y)-solve(x-1));
    59     }
    60 }
    View Code

      那么如果是求区间内还有62的呢?可以是在上面的基础上,用总个数减去;也可以再开一维have,表示是否拥有了62。这样变化一下就是D题了,丢代码:

     1 #include <stdio.h>
     2 #include <algorithm>
     3 #include <string.h>
     4 #include <iostream>
     5 #include <vector>
     6 #include <queue>
     7 #include <math.h>
     8 using namespace std;
     9 typedef long long ll;
    10 
    11 
    12 int bit[22];
    13 ll dp[22][2][2];
    14 
    15 /*ll dfs(int pos,bool state,bool flag)
    16 {
    17     if(pos == -1) return 1;
    18     ll& ans = dp[pos][state];
    19     if(flag && ans!=-1) return ans;
    20     int d = flag?9:bit[pos];
    21 
    22     ll ret = 0;
    23     for(int i=0;i<=d;i++)
    24     {
    25         if(state && i==9) continue;
    26         ret += dfs(pos-1,i==4,flag||i<d);
    27     }
    28     if(flag) ans = ret;
    29     return ret;
    30 }*/
    31 
    32 ll dfs(int pos,bool state,bool have,bool flag)
    33 {
    34     if(pos == -1) return have;
    35     ll& ans = dp[pos][have][state];
    36     if(flag && ans!=-1) return ans;
    37     int d = flag?9:bit[pos];
    38 
    39     ll ret = 0;
    40     for(int i=0;i<=d;i++)
    41     {
    42         if(state && i==9) ret += dfs(pos-1,i==4,1,flag||i<d);
    43         else ret += dfs(pos-1,i==4,have,flag||i<d);
    44     }
    45     if(flag) ans = ret;
    46     return ret;
    47 }
    48 
    49 ll solve(ll x)
    50 {
    51     //if(x==0) return 0;
    52     int pos = 0;
    53     while(x)
    54     {
    55         bit[pos++] = x % 10;
    56         x /= 10;
    57     }
    58 
    59     ll ans = 0;
    60     ans += dfs(pos-1,false,false,false);
    61     return ans;
    62 }
    63 
    64 int main()
    65 {
    66     int T;
    67     scanf("%d",&T);
    68     memset(dp,-1,sizeof(dp));
    69 
    70     while(T--)
    71     {
    72         ll x;
    73         scanf("%I64d",&x);
    74         //printf("%I64d
    ",x-(solve(x)-1));
    75         cout<<solve(x)<<endl;
    76     }
    77 }
    View Code

    另外还需要注意的是上面的问题并不需要记录当前一位是哪个数字,只要记录是不是需要的数字即可。比方说62,我只要用一个state保存这一位是不是6即可。

      以上就是数位dp的基本套路,但是还会遇到一种情况,比方说让你统计区间内数的二进制表示的0的个数,这样子在dfs时需要再加一个参数first来表示有没有前导0。举个例子,比如说10010,在dfs以后,如果变成了0XXXX的情况,显然第一个0是不能算在0的个数之内的。因此,如果数位dp的过程中,有无前导0会对结果造成影响的,就要再加一个参数first来辅助完成dp过程。具体见E题代码:

     1 #include <stdio.h>
     2 #include <algorithm>
     3 #include <string.h>
     4 using namespace std;
     5 
     6 int bit[40];
     7 int dp[40][40][40];
     8 //int len;
     9 
    10 /*
    11     单纯的用len来记录总位数是不行的,因为比方说5是101,
    12     比它小的数可能len不等于3
    13     也就是说在递推的过程中len是会发生变化的
    14     
    15     因此,换一个方法,用一个bool变量first记录之前位是否存在
    16 */
    17 
    18 int dfs(int pos,int num_zero,int num_one,bool flag,bool first)
    19 {
    20     if(pos == -1) return num_zero >= num_one;
    21     int& ans = dp[pos][num_zero][num_one];
    22     if(flag && ans!=-1) return ans;
    23     
    24     int d = flag?1:bit[pos];
    25     int ret = 0;
    26     for(int i=0;i<=d;i++)
    27     {
    28         ret += dfs(pos-1,(first||i?num_zero+(i==0):0),(first||i?num_one+(i==1):0),flag||i<d,first||i);
    29     }
    30     if(flag) ans = ret;
    31     return ret;
    32 }
    33 
    34 int solve(int x)
    35 {
    36     int pos = 0;
    37     while(x)
    38     {
    39         bit[pos++] = x%2;
    40         x/=2;
    41     }
    42     
    43     //len = pos;
    44     return dfs(pos-1,0,0,false,false);
    45 }
    46 
    47 int main()
    48 {
    49     int x,y;
    50     memset(dp,-1,sizeof(dp));
    51     while(scanf("%d%d",&x,&y)==2)
    52     {
    53         printf("%d
    ",solve(y)-solve(x-1));
    54     }
    55     return 0;
    56 }
    View Code

      然后几题是比较有意思的。

      A题,问区间内的美丽数的个数,美丽数指的是:这个数能被各个位置上的数字整除。这题的方法是直接找他们的最小公倍数,然后当做状态的状态的一维即可。要注意的是如果最小公倍数开最大的话,会超内存,那么只要将这一系列的最小公倍数离散化即可。其实这题似乎还可以优化,我没有深究了。具体见代码:

     1 #include <stdio.h>
     2 #include <algorithm>
     3 #include <string.h>
     4 #include <iostream>
     5 #include <vector>
     6 #include <queue>
     7 #include <math.h>
     8 using namespace std;
     9 typedef long long ll;
    10 
    11 int gcd(int a,int b) {return a%b?gcd(b,a%b):b;}
    12 int lcm(int a,int b) {return a*b/gcd(a,b);}
    13 // pos,离散化后的公约数的位置,各个位置上数字的和
    14 int a[2520+10];
    15 int bit[22];
    16 ll dp[22][50][2520+100];
    17 
    18 ll dfs(int pos,int _lcm,int sum,bool flag)
    19 {
    20     if(pos == -1) return (ll)(sum%_lcm==0);
    21     ll& ans = dp[pos][a[_lcm]][sum];
    22     if(flag && ans!=-1) return ans;
    23     int d = flag?9:bit[pos];
    24 
    25     ll ret = 0;
    26     for(int i=0;i<=d;i++)
    27     {
    28         ret += dfs(pos-1,i?lcm(_lcm,i):_lcm,(sum*10+i)%2520,flag||i<d);
    29     }
    30     if(flag) ans = ret;
    31     return ret;
    32 }
    33 
    34 ll solve(ll x)
    35 {
    36     int pos = 0;
    37     while(x)
    38     {
    39         bit[pos++] = x % 10;
    40         x /= 10;
    41     }
    42 
    43     ll ans = 0;
    44     ans += dfs(pos-1,1,0,false);
    45     return ans;
    46 }
    47 
    48 int main()
    49 {
    50     // 离散化
    51     for(int i=1,j=0;i<=2520;i++)
    52     {
    53         a[i] = 2520%i?0:(++j);
    54         //printf("%d !!
    ",j );
    55     } // max = 48
    56     
    57     int T;
    58     scanf("%d",&T);
    59     memset(dp,-1,sizeof(dp));
    60 
    61     while(T--)
    62     {
    63         ll x,y;
    64         scanf("%I64d%I64d",&x,&y);
    65         printf("%I64d
    ",solve(y)-solve(x-1));
    66     }
    67 }
    View Code

      B题,问区间内满足以下条件的数,这个数字内的LIS等于K。那么只要模拟LIS用二进制储存一个状态码当做一维的内容即可。具体见代码:

     1 #include <stdio.h>
     2 #include <algorithm>
     3 #include <string.h>
     4 #include <vector>
     5 #include <map>
     6 #include <set>
     7 #include <queue>
     8 #include <iostream>
     9 #include <stdlib.h>
    10 #include <string>
    11 #include <stack>
    12 using namespace std;
    13 const int inf = 0x3f3f3f3f;
    14 typedef long long ll;
    15 typedef pair<int,int> pii;
    16 const int N = 500 + 5;
    17 
    18 int bit[22],K;
    19 ll dp[22][1<<10][12];
    20 ll L,R;
    21 
    22 /*
    23     state是一个状态码,
    24     二进制状态下,各位如果是0表示LIS中没有这个数,
    25     否则,就有。
    26     例如100110,那么表示当前这个数的LIS为125
    27     如果我要插入一个数字是3,那么3将5替换掉就可以了,
    28     否则,如果插入的是一个6,比5都大,只要将6放在5后面即可。
    29     下面的代码就是实现的这一过程。
    30 */
    31 int newState(int state,int x)
    32 {
    33     // 找到第一个大于x的数将它替换掉
    34     // 是对LIS的nlog(n)的体现
    35     for(int i=x;i<10;i++)
    36     {
    37         if(state & (1<<i)) return (state^(1<<i))|(1<<x);
    38     }
    39     return state | (1<<x);
    40 }
    41 
    42 int getLen(int state)
    43 {
    44     int cnt = 0;
    45     while(state)
    46     {
    47         if(state&1) cnt++;
    48         state >>= 1;
    49     }
    50     return cnt;
    51 }
    52 
    53 ll dfs(int pos,int state,bool first,bool flag)
    54 {
    55     if(pos == -1) return getLen(state)==K;
    56     ll& ans = dp[pos][state][K];
    57     if(flag && ans!=-1) return ans;
    58 
    59     int d = flag?9:bit[pos];
    60     ll ret = 0;
    61     for(int i=0;i<=d;i++)
    62     {
    63         ret += dfs(pos-1,first||i?newState(state,i):0,first||i,flag||i<d);
    64     }
    65     if(flag) ans = ret;
    66     return ret;
    67 }
    68 
    69 ll solve(ll x)
    70 {
    71     int pos = 0;
    72     while(x)
    73     {
    74         bit[pos++] = x % 10;
    75         x /= 10;
    76     }
    77 
    78     return dfs(pos-1,0,false,false);
    79 }
    80 
    81 int main()
    82 {
    83     int T;
    84     scanf("%d",&T);
    85     memset(dp,-1,sizeof(dp));
    86     for(int kase=1;kase<=T;kase++)
    87     {
    88         scanf("%I64d%I64d%d",&L,&R,&K);
    89         printf("Case #%d: %I64d
    ",kase,solve(R)-solve(L-1));
    90     }
    91 }
    View Code

      讲到状态码形式的数位dp,可以再看看最后一题。大意是找出区间内平衡数:任何奇数只要出现了,就必须出现偶数次;任何偶数,只要出现了就必须出现奇数次。不出现的数字不做讨论,同时是任意一个奇数都要出现偶数次,比方说1333,1和3它们都出现了奇数次,而不是所有奇数出现的次数和是偶数次,因此这个数字不满足平衡数的要求。具体实现的话也是转化成3进制的状态码即可,具体见代码:

      1 #include <stdio.h>
      2 #include <algorithm>
      3 #include <string.h>
      4 #include <vector>
      5 #include <map>
      6 #include <set>
      7 #include <queue>
      8 #include <iostream>
      9 #include <stdlib.h>
     10 #include <string>
     11 #include <stack>
     12 using namespace std;
     13 const int inf = 0x3f3f3f3f;
     14 typedef long long ll;
     15 typedef pair<int,int> pii;
     16 const int N = 500 + 5;
     17 
     18 int bit[22];
     19 ll dp[22][60000];
     20 int pw[12];
     21 int f[60000][10];
     22 // f表示的是在这个状态码下各个数字的出现次数
     23 // 是预处理好的
     24 // 比如f[12345][6]表示在12345这个状态码下的6的出现的次数的奇偶
     25 
     26 bool check(int state)
     27 {
     28     int x = 1; // 1表示出现奇数次,2表示出现偶数次
     29     for(int i=9;i>=0;i--)
     30     {
     31         if(f[state][i])
     32         {
     33             if(f[state][i] + x != 3) return false;
     34         }
     35         x = 3 - x;
     36     }
     37     return true;
     38 }
     39 
     40 int newstate(int state,int i)
     41 {
     42     if(f[state][i] <= 1)
     43     {
     44         //f[state][i]++;
     45         return state + pw[i]; // 如果出现了0次或者奇数次,就次数加1,相当于那个数字的那一位的三进制加1
     46     }
     47     //f[state][i]--;
     48     return state - pw[i]; // 如果出现了偶数次,让状态码对应位置的三进制减1,表示变成了出现奇数次
     49 }
     50 
     51 ll dfs(int pos,int state,bool first,bool flag)
     52 {
     53     if(pos == -1) return check(state);
     54     ll& ans = dp[pos][state];
     55     if(flag && ans!=-1) return ans;
     56 
     57     int d = flag?9:bit[pos];
     58     ll ret = 0;
     59     for(int i=0;i<=d;i++)
     60     {
     61         ret += dfs(pos-1,first||i?newstate(state,i):0,first||i,flag||i<d);
     62     }
     63     if(flag) ans = ret;
     64     return ret;
     65 }
     66 
     67 ll solve(ll x)
     68 {
     69     int pos = 0;
     70     while(x)
     71     {
     72         bit[pos++] = x % 10;
     73         x /= 10;
     74     }
     75 
     76     return dfs(pos-1,0,false,false);
     77 }
     78 
     79 void init()
     80 {
     81     memset(dp,-1,sizeof(dp));
     82     memset(f,0,sizeof(f));
     83     pw[0] = 1;
     84     for(int i=1;i<=10;i++) pw[i] = 3*pw[i-1];
     85         
     86     // 下面是预处理出f的值
     87     for(int i=1;i<=pw[10]-1;i++)
     88     {
     89         int now = i;
     90         for(int j=9;j>=0;j--)
     91         {
     92             if(now >= pw[j])
     93             {
     94                 int t;
     95                 f[i][j] = t = now/pw[j];
     96                 now -= t*pw[j];
     97             }
     98         }
     99     }
    100 }
    101 
    102 int main()
    103 {
    104     int T;
    105     scanf("%d",&T);
    106     init();
    107     while(T--)
    108     {
    109         ll l,r;
    110         cin >> l >> r;
    111         cout << solve(r)-solve(l-1) << endl;
    112     }
    113 }
    View Code

      最后想讲的是J题,求的是区间内的与7无关的数字的平方和。因为涉及到平方,所以涉及到展开的问题。具体见代码注释:

     1 #include <stdio.h>
     2 #include <algorithm>
     3 #include <string.h>
     4 #include <vector>
     5 #include <map>
     6 #include <set>
     7 #include <queue>
     8 #include <iostream>
     9 #include <stdlib.h>
    10 #include <string>
    11 #include <stack>
    12 using namespace std;
    13 const int inf = 0x3f3f3f3f;
    14 const int mod = (int)1e9 + 7;
    15 typedef long long ll;
    16 typedef pair<int,int> pii;
    17 const int N = 500 + 5;
    18 
    19 struct node
    20 {
    21     ll n,s,sq;
    22 }dp[22][10][10];
    23 int bit[22];
    24 ll L,R,pw[22];
    25 
    26 node dfs(int pos,int sum,int digit_sum,bool flag)
    27 {
    28     if(pos == -1) return (node){sum&&digit_sum,0,0};
    29     node& ans = dp[pos][sum][digit_sum];
    30     if(flag && ans.n!=-1) return ans;
    31 
    32     int d = flag?9:bit[pos];
    33     node ret = (node){0,0,0};
    34     for(int i=0;i<=d;i++)
    35     {
    36         if(i == 7) continue;
    37         node temp = dfs(pos-1,(sum*10+i)%7,(digit_sum+i)%7,flag||i<d);
    38         ret.n = (ret.n + temp.n) % mod;
    39 
    40         // 别忘了乘以个数n
    41         ret.s += (temp.s + (pw[pos] * i) % mod * temp.n) % mod;
    42         ret.s %= mod;
    43 
    44         // 到这一位需要增加的sq是(i*pw[i]+temp.s)^2,拆开累加即可
    45         // temp.sq 是 temp.s 的平方
    46         // 在处理(i*pw[i])的平方时需要乘以个数n。
    47         // 而处理2倍它们的乘积时不用是因为temp.s中已经乘过n了
    48         ret.sq += (temp.sq + 2 * pw[pos] * i % mod * temp.s % mod) % mod;
    49         ret.sq %= mod;
    50         ret.sq += (temp.n * pw[pos] % mod * ((pw[pos]*i*i) % mod)) % mod;
    51         ret.sq %= mod;
    52     }
    53     if(flag) ans = ret;
    54     return ret;
    55 }
    56 
    57 ll solve(ll x)
    58 {
    59     int pos = 0;
    60     while(x)
    61     {
    62         bit[pos++] = x % 10;
    63         x /= 10;
    64     }
    65 
    66     return dfs(pos-1,0,0,false).sq % mod;
    67 }
    68 
    69 int main()
    70 {
    71     int T;
    72     scanf("%d",&T);
    73     memset(dp,-1,sizeof(dp));
    74     pw[0] = 1;
    75     for(int i=1;i<22;i++) pw[i] = (pw[i-1] * 10) % mod;
    76     while(T--)
    77     {
    78         scanf("%I64d%I64d",&L,&R);
    79         // 对mod以后的别忘记先加mod再mod不然可能是负数
    80         printf("%I64d
    ",(solve(R)-solve(L-1) + mod) % mod);
    81     }
    82 }
    View Code

      最后想补充的一点是,很多的solve求的都是0~指定位置满足条件的和,但是没有关系,只要相减以后就都抵消了,但是单单使用solve的话就可能会出现问题。

  • 相关阅读:
    ParksLink修改密码
    ORA-01940:无法删除当前已链接的用户
    imp导入数据的时候报错:ORA-01658: 无法为表空间 MAXDATA 中的段创建 INITIAL 区
    Linux下查看日志用到的常用命令
    大批量数据高效插入数据库表
    线程中断:Thread类中interrupt()、interrupted()和 isInterrupted()方法详解
    CyclicBarrier、CountDownLatch、Callable、FutureTask、thread.join() 、wait()、notify()、Condition
    Mysql全文索引
    Docker 镜像的常用操作
    Docker 入门
  • 原文地址:https://www.cnblogs.com/zzyDS/p/5684209.html
Copyright © 2020-2023  润新知