题意:给出多个经过加密的模式串和多个待匹配的串,问每个待匹配串中出现了多少种不同的模式串。加密方法即 Base64 编码:每 3 个字节(24 位)按每 6 位一组拆成 4 组,每组按题目给出的对应表映射成一个字符;末尾不足 3 个字节时用 '=' 补齐。
分析:除了解密之外,基本就是一道裸的 AC 自动机题。要注意的是有多个待匹配串要依次在同一个自动机上查询,而 vis 数组记录的是当前串已经统计过的结点,所以每查询一个串之前都要把 vis 数组清零。
#include <cstdio>
#include <queue>
#include <cstring>
#include <cctype>
using namespace std;

// Debug-trace hook: expands to nothing, so every D(...) call is compiled out.
// Redefine as `#define D(x) x` to re-enable the traces.
#define D(x)

// Alphabet size of the trie: patterns and texts are decoded bytes in [0, 255].
const int MAX_CHILD_NUM = 256;
// Upper bound on trie nodes (presumably sized for the problem's pattern
// count/length limits -- TODO confirm against the judge's constraints).
const int MAX_NODE_NUM = 100 * 512 + 10;
// Capacity of the per-line input/decoding buffers.
const int MAX_LEN = 3000 + 10;

char st[MAX_LEN];        // raw Base64 text of the line currently being processed
int st2[MAX_LEN];        // decoded bytes of that line, terminated by -1
bool vis[MAX_NODE_NUM];  // per-query "already counted" marks; cleared before each query

// Aho-Corasick automaton over byte sequences.
struct Trie {
    int next[MAX_NODE_NUM][MAX_CHILD_NUM]; // goto table; build() fills every -1 entry
    int fail[MAX_NODE_NUM];                // failure links
    int count[MAX_NODE_NUM];               // number of patterns ending exactly at this node
    int node_cnt;                          // number of nodes allocated so far
    int root;

    // Reset the automaton to a single, empty root node.
    void init() {
        node_cnt = 0;
        root = newnode();
    }

    // Allocate a node with no children and no terminating pattern; return its index.
    int newnode() {
        for (int i = 0; i < MAX_CHILD_NUM; i++)
            next[node_cnt][i] = -1;
        count[node_cnt++] = 0;
        return node_cnt - 1;
    }

    // Map an input symbol to a child slot (identity: symbols are already bytes).
    int get_id(int a) { return a; }

    // Insert the -1-terminated symbol sequence `buf` as a pattern.
    // `id` is the pattern's 1-based index; it is currently not used.
    void insert(int buf[], int id) {
        (void)id;
        int now = root;
        for (int i = 0; buf[i] != -1; i++) {
            // Named `c` so it no longer shadows the `id` parameter.
            int c = get_id(buf[i]);
            if (next[now][c] == -1)
                next[now][c] = newnode();
            now = next[now][c];
        }
        count[now]++;
    }

    // Compute failure links breadth-first and fill every missing transition,
    // so that afterwards `next` contains no -1 entries (a complete DFA).
    void build() {
        queue<int> Q;
        fail[root] = root;
        for (int i = 0; i < MAX_CHILD_NUM; i++) {
            if (next[root][i] == -1) {
                next[root][i] = root;        // unmatched byte at the root stays at the root
            } else {
                fail[next[root][i]] = root;  // depth-1 nodes fall back to the root
                Q.push(next[root][i]);
            }
        }
        while (!Q.empty()) {
            int now = Q.front();
            Q.pop();
            for (int i = 0; i < MAX_CHILD_NUM; i++) {
                if (next[now][i] == -1) {
                    next[now][i] = next[fail[now]][i];
                } else {
                    fail[next[now][i]] = next[fail[now]][i];
                    Q.push(next[now][i]);
                }
            }
        }
    }

    // Return how many patterns (counted with multiplicity of insertion) occur
    // at least once in the -1-terminated text `buf`.
    // Requires the global `vis` to be false for all used nodes on entry; nodes
    // are left marked on return, so the caller must clear `vis` before the
    // next query.
    int query(int buf[]) {
        int now = root;
        int res = 0;
        for (int i = 0; buf[i] != -1; i++) {
            now = next[now][get_id(buf[i])];
            int temp = now;
            while (temp != root && !vis[temp]) {
                res += count[temp];
                // Optimization: never walk this node's fail chain again for the
                // current text. This also keeps any pattern from being counted twice.
                vis[temp] = true;
                temp = fail[temp];
            }
        }
        return res;
    }

    // Dump every node (failure link, pattern count, child table) to stdout.
    void debug() {
        for (int i = 0; i < node_cnt; i++) {
            printf("id = %3d,fail = %3d,end = %3d,chi = [", i, fail[i], count[i]);
            for (int j = 0; j < MAX_CHILD_NUM; j++)
                printf("%2d", next[i][j]);
            printf("]\n");
        }
    }
} ac;

int n, m; // number of patterns / number of texts in the current test case

// Decode one Base64 character to its 6-bit value (A-Z, a-z, 0-9, '+', '/').
// Any other character, including the '=' padding, yields 63. Padded positions
// only influence bytes that transform() discards afterwards.
int get_value(char ch) {
    // <cctype> functions require a value representable as unsigned char.
    unsigned char u = static_cast<unsigned char>(ch);
    if (isupper(u)) return u - 'A';
    if (islower(u)) return u - 'a' + 26;
    if (isdigit(u)) return u - '0' + 52;
    if (u == '+') return 62;
    return 63;
}

// Base64-decode the NUL-terminated string `st` into `st2` and append a -1
// terminator. Each 4-character group yields 3 bytes (most significant first);
// every trailing '=' drops one byte from the end of the output.
void transform(char *st, int *st2) {
    int len = (int)strlen(st);
    int len2 = len * 3 / 4;
    for (int i = 0; i < len; i += 4) {
        int a = 0;
        for (int j = 0; j < 4; j++) {
            // A truncated final group contributes zero bits instead of reading
            // past the string terminator.
            int v = (i + j < len) ? get_value(st[i + j]) : 0;
            a = (a << 6) + v;
            D(printf("**%d ", a));
        }
        for (int j = 2; j >= 0; j--) {
            st2[i * 3 / 4 + j] = a % (1 << 8);
            a >>= 8;
            D(printf("**%d ", st2[i * 3 / 4 + j]));
        }
    }
    // Guard on len so an all-'=' string cannot index st[-1].
    while (len > 0 && st[len - 1] == '=') {
        len--;
        len2--;
    }
    st2[len2] = -1;
    D(puts("#"));
    for (int i = 0; i < len2; i++) {
        D(printf("%d ", st2[i]));
    }
    D(puts(""));
}

// Process one test case: read and insert the `n` encoded patterns, build the
// automaton, then read `m` encoded texts and print one count per line,
// followed by a blank line after the case.
void input() {
    for (int i = 1; i <= n; i++) {
        scanf("%s", st);
        transform(st, st2);
        ac.insert(st2, i);
    }
    ac.build();
    scanf("%d", &m);
    for (int i = 0; i < m; i++) {
        scanf("%s", st);
        transform(st, st2);
        // Only nodes [0, node_cnt) can have been marked by query().
        memset(vis, 0, sizeof(vis[0]) * ac.node_cnt);
        printf("%d\n", ac.query(st2));
    }
    puts("");
}

int main() {
    // `== 1` rather than `!= EOF`: if the next token is not an integer, scanf
    // returns 0 and the original loop would spin forever.
    while (scanf("%d", &n) == 1) {
        ac.init();
        input();
    }
    return 0;
}