dp[i][j] 表示走了 i 步后恰好停在 AC 自动机结点 j 的概率,初始值 dp[0][0] = 1(0 号结点为根)。只有当要走到的结点不是单词尾结点时,才能走过去;最终答案为所有 dp[l][j] 之和(l 为字符串长度)。
若保留 last 数组(沿 fail 链向上最近的单词尾结点),则当且仅当 !end[i] && last[i] == root 时,结点 i 才可以走。
也可以丢掉 last 数组,在构建 fail 指针时令 end[i] |= end[ fail[i] ] 即可:BFS 按深度处理结点,处理 i 时 end[fail[i]] 已经确定,因此 end[i] 为真表示某个禁止字符串是结点 i 所代表字符串的后缀。
#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>

// Capacity of the automaton (number of trie nodes, root included).
// NOTE(review): assumes the total length of all patterns in one test case is
// below N; newnode() does not check for overflow — confirm against the
// problem limits.
const int N = 505;
// Alphabet size: '0'-'9', 'a'-'z', 'A'-'Z'.
const int SIGMA = 62;

// Maps an alphanumeric character to an index in [0, SIGMA):
// '0'-'9' -> 0..9, 'a'-'z' -> 10..35, 'A'-'Z' -> 36..61.
// Returns -1 for any other character; callers must check for it.
int id(char c) {
    if (c >= '0' && c <= '9') return c - '0';
    if (c >= 'a' && c <= 'z') return c - 'a' + 10;
    if (c >= 'A' && c <= 'Z') return c - 'A' + 36;
    return -1;
}

// Aho-Corasick automaton over the SIGMA-letter alphabet defined by id().
// Usage: init(), insert() every pattern, then build(). Afterwards nex[][] is a
// complete transition table, and end[v] != 0 exactly when some inserted
// pattern is a suffix of the string spelled by the path root -> v.
struct Trie {
    int nex[N][SIGMA];  // child links; after build(), the full goto function
    int fail[N];        // failure links (longest proper suffix present in the trie)
    int end[N];         // pattern-end flags, propagated along fail links by build()
    int root, L;        // root index and number of nodes in use

    // Allocates a fresh node with no children and no pattern ending at it.
    int newnode() {
        std::memset(nex[L], -1, sizeof(nex[L]));
        end[L] = 0;
        return L++;
    }

    // Resets the automaton to a single root node.
    void init() {
        L = 0;
        root = newnode();
    }

    // Adds pattern s. A pattern containing a character outside the alphabet
    // can never occur in a string over that alphabet, so it is skipped
    // instead of indexing nex[][] with -1.
    void insert(const char* s) {
        for (const char* q = s; *q; ++q)
            if (id(*q) < 0) return;
        int now = root;
        for (int i = 0; s[i]; i++) {
            int p = id(s[i]);
            if (nex[now][p] == -1)
                nex[now][p] = newnode();
            now = nex[now][p];
        }
        end[now] = 1;
    }

    // Computes failure links by BFS and fills the missing entries of nex[][]
    // so that every (node, symbol) pair has a transition. BFS visits nodes in
    // order of depth, so end[fail[u]] is already final when u is reached;
    // OR-ing it in marks every node that has some pattern as a suffix.
    void build() {
        std::queue<int> Q;
        fail[root] = root;
        for (int i = 0; i < SIGMA; i++) {
            int& u = nex[root][i];
            if (u == -1) {
                u = root;
            } else {
                fail[u] = root;
                Q.push(u);
            }
        }
        while (!Q.empty()) {
            int now = Q.front();
            Q.pop();
            for (int i = 0; i < SIGMA; i++) {
                int& u = nex[now][i];
                if (u == -1) {
                    u = nex[fail[now]][i];
                } else {
                    fail[u] = nex[fail[now]][i];
                    end[u] |= end[fail[u]];
                    Q.push(u);
                }
            }
        }
    }
};

// Probability that a string of length len, whose characters are drawn
// independently with P(symbol c) = prob[c], never enters a pattern-end node
// of the built automaton ac, i.e. contains none of the inserted patterns.
// Keeps only two DP rows, so len is not limited by a fixed table size.
double survival_probability(const Trie& ac, const double* prob, int len) {
    // cur[v]: probability of standing on node v after the steps taken so far
    // without ever having stepped onto a pattern-end node.
    std::vector<double> cur(ac.L, 0.0), nxt(ac.L, 0.0);
    cur[ac.root] = 1.0;
    for (int step = 0; step < len; ++step) {
        nxt.assign(ac.L, 0.0);
        for (int v = 0; v < ac.L; ++v) {
            if (cur[v] > 0) {
                for (int c = 0; c < SIGMA; ++c) {
                    int to = ac.nex[v][c];
                    if (!ac.end[to]) nxt[to] += cur[v] * prob[c];
                }
            }
        }
        cur.swap(nxt);
    }
    double total = 0;
    for (int v = 0; v < ac.L; ++v) total += cur[v];
    return total;
}

static Trie ac;  // about 130 KB, so kept out of main's stack frame

int main() {
    int t;
    if (std::scanf("%d", &t) != 1) return 0;
    for (int ca = 1; ca <= t; ++ca) {
        ac.init();
        int k;
        if (std::scanf("%d", &k) != 1) return 0;
        char pattern[512];
        for (int i = 0; i < k; i++) {
            // Width limit prevents overflowing the buffer on long input tokens.
            if (std::scanf("%511s", pattern) != 1) return 0;
            ac.insert(pattern);
        }
        ac.build();

        double por[SIGMA] = {0};  // per-symbol probability; 0 for symbols not listed
        int n;
        if (std::scanf("%d", &n) != 1) return 0;
        for (int i = 0; i < n; i++) {
            char ch;
            double prob;
            if (std::scanf(" %c%lf", &ch, &prob) != 2) return 0;
            int p = id(ch);
            if (p >= 0) por[p] = prob;  // ignore non-alphabet symbols instead of writing por[-1]
        }
        int l;
        if (std::scanf("%d", &l) != 1) return 0;
        std::printf("Case #%d: %.6f\n", ca, survival_probability(ac, por, l));
    }
    return 0;
}