这题想了好一会呢。。刚开始想错了,以为用自动机预处理出k长度可以包含的合法的数的个数,然后再数位dp一下就行了,写到一半发现不对,还要处理当前走的时候是不是为合法的,这一点无法移到trie树上去判断。
之后想到应该在trie树上进行数位dp,走到第i个节点且长度为j的状态是确定的,所以可以根据trie树上的节点来进行确定状态。
dp[i][j]表示当前节点为i,数第j位时可以包含多少个合法的数。
1 #include <iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<string> 5 #include<algorithm> 6 #include<stdlib.h> 7 #include<vector> 8 #include<cmath> 9 #include<queue> 10 #include<set> 11 using namespace std; 12 #define N 2010 13 #define LL long long 14 #define INF 0xfffffff 15 const double eps = 1e-8; 16 const double pi = acos(-1.0); 17 const double inf = ~0u>>2; 18 const int child_num = 2; 19 const int mod = 1000000009; 20 int dp[210][N]; 21 char s1[210],s2[210]; 22 class AC 23 { 24 private: 25 int ch[N][child_num]; 26 int Q[N]; 27 int fail[N]; 28 int val[N]; 29 int id[127]; 30 int sz; 31 int dd[810][N]; 32 public: 33 void init() 34 { 35 fail[0] = 0; 36 id['0'] = 0;id['1'] = 1; 37 } 38 void reset() 39 { 40 memset(val,0,sizeof(val)); 41 memset(ch[0],0,sizeof(ch[0])); 42 sz=1; 43 } 44 void insert(char *a,int key) 45 { 46 int p =0 ; 47 for( ; *a ; a++) 48 { 49 int d = id[*a]; 50 if(ch[p][d]==0){ 51 memset(ch[sz],0,sizeof(ch[sz])); 52 ch[p][d] = sz++; 53 } 54 p = ch[p][d]; 55 } 56 val[p] = key; 57 } 58 void construct() 59 { 60 int i,head=0,tail = 0; 61 for(i = 0 ;i < child_num ; i++) 62 { 63 if(ch[0][i]) 64 { 65 fail[ch[0][i]] = 0; 66 Q[tail++] = ch[0][i]; 67 } 68 } 69 while(head!=tail) 70 { 71 int u = Q[head++]; 72 val[u]|=val[fail[u]]; 73 for(i =0 ;i < child_num ; i++) 74 { 75 if(ch[u][i]) 76 { 77 fail[ch[u][i]] = ch[fail[u]][i]; 78 Q[tail++] = ch[u][i]; 79 } 80 else ch[u][i] = ch[fail[u]][i]; 81 } 82 } 83 } 84 int dfs(char *s,int i,int c,int e,int k) 85 { 86 if(i==-1) 87 { 88 return 1; 89 } 90 if(!e&&~dp[i][c]) 91 { 92 return dp[i][c]; 93 } 94 int mk = e?s[i]-'0':9; 95 int ans = 0; 96 for(int j = 0; j <= mk ; j++) 97 { 98 if(!k&&j==0&&i) 99 { 100 ans = (ans+dfs(s,i-1,c,e&&j==mk,k)); 101 continue; 102 } 103 int p = c,flag = 1; 104 for(int g = 3; g >=0 ; g--) 105 { 106 int o = (j&(1<<g))?1:0; 107 p = ch[p][o]; 108 int tmp = p; 109 while(tmp!=0) 110 { 111 if(val[tmp]) 112 { 113 flag = 0; 114 break; 115 } 116 tmp = fail[tmp]; 117 } 118 if(!flag) break; 119 } 120 if(flag) 121 { 122 ans = (ans+dfs(s,i-1,p,e&&j==mk,1))%mod; 123 } 124 } 125 return e?ans:dp[i][c] = ans; 126 } 127 void work(char *s1,char *s2) 128 { 129 memset(dp,-1,sizeof(dp)); 130 printf("%d ",(dfs(s2,strlen(s2)-1,0,1,0)-dfs(s1,strlen(s1)-1,0,1,0)+mod)%mod); 131 } 132 }ac; 133 char vir[22]; 134 char ss1[210],ss2[210]; 135 int main() 136 { 137 int t,n,i; 138 ac.init(); 139 scanf("%d",&t); 140 while(t--) 141 { 142 ac.reset(); 143 scanf("%d",&n); 144 while(n--) 145 { 146 scanf("%s",vir); 147 ac.insert(vir,1); 148 } 149 ac.construct(); 150 scanf("%s%s",s1,s2); 151 int k = strlen(s1),kk= strlen(s2); 152 for(i = k-1 ; i >= 0; i--) 153 { 154 if(s1[i]>'0') 155 { 156 s1[i]-=1; 157 break; 158 } 159 else 160 s1[i] = '9'; 161 } 162 for(i = 0; i < k ; i++) 163 ss1[k-1-i] = s1[i]; 164 ss1[k] = '