题意:给出字符串长度n(<=2000000000),给出不可以包含的序列,最多10个,每个长度最大是10。问长度为n的合法序列有多少个?序列中只可能包含ACTG四个字符。
分析:AC自动机(DFA)+矩阵快速幂
ac自动机上的等价态:
等价态即用fail指针连接的点,在行走fail指针时匹配的字符数量并没有发生变化,因此这些点可以看成是相同的匹配状态。
通常有两种方法处理等价态,第一是互为等价态的点各自记录各自的信息。匹配的时候需要遍历所有等价态以判断是否匹配成功。next指针可能为空,需要匹配时进行判断是否需要走fail指针。
第二是所有等价态中的点记录本身以及所有比它浅的点的信息总和(匹配成功的单词总数),匹配时不需要走等价态以判断匹配成功与否。next指针不为空,直接指向本应通过fail指针寻找到的那个状态。
ac自动机与矩阵:
在ac自动机上,每一个从根出发并在自动机上行走的任意长度的路径都代表了一个字符串。
把ac自动机看成一个有向图的话我们可以提取它的邻接矩阵(可达矩阵),matrix[i][j]表示i和j是否相邻。
这个矩阵的n次幂matrix^n[i][j]表示从i恰好走n步到达j的路径有几条。
那可达矩阵对等价态是怎么处理的呢?如果考虑等价态,一个状态的可到达状态实在是太多了。因此我们这里认为的可达只是用地二中方法处理等价态时,next指针直接指向的被认为可达。
本题其实就是在ac自动机上找出 从根出发的 长度为n的 不经过任何匹配成功态的 路径数量。
这个只需要用矩阵快速幂来计算即可,一个状态是否是匹配成功态的。注意判断成功态和等价态的问题。
#include <iostream> #include <cstdio> #include <cstring> #include <cstdlib> #include <queue> using namespace std; #define D(x) #define MAX_LEN 20 #define MAX_CHILD_NUM 4 #define MAX_NODE_NUM 100005 int st_num, len; int matrix_size; int node_cnt; struct node { node *fail; node *next[MAX_CHILD_NUM]; int count; //how many words are matched when reach this node }trie[MAX_NODE_NUM], *root = trie; void ac_init() { memset(trie, 0, sizeof(trie)); node_cnt = 1; } int get_id(char ch) { if (ch == 'A') return 0; if (ch == 'C') return 1; if (ch == 'T') return 2; return 3; } void insert(node *root, char *str) { node *p = root; int index; for (int i = 0; str[i]; i++) { index = get_id(str[i]); if (p->next[index] == NULL) { p->next[index] = trie + node_cnt; node_cnt++; } p = p->next[index]; } p->count++; } void build_ac_automation(node *root) { queue<node*> q; int i; root->fail = NULL; q.push(root); while (!q.empty()) { node *temp = q.front(); q.pop(); node *p = NULL; for (i = 0; i < MAX_CHILD_NUM; i++) { p = temp->fail; while (p != NULL && p->next[i] == NULL) p = p->fail; if (temp->next[i] != NULL) { if (p == NULL) temp->next[i]->fail = root; else { temp->next[i]->fail = p->next[i]; temp->next[i]->count += p->next[i]->count; } q.push(temp->next[i]); }else { if (p == NULL) temp->next[i] = root; else temp->next[i] = p->next[i]; } } } } int query(node *root, char* str) { int cnt = 0, index; node *p = root; for (int i = 0; str[i]; i++) { index = get_id(str[i]); p = p->next[index]; p = (p == NULL) ? root : p; node *temp = p; cnt += temp->count; //marks count as -1 to prevent from matching again while (temp != root && temp->count != -1) { temp->count = -1; temp = temp->fail; } } return cnt; } void input() { scanf("%d%d", &st_num, &len); for (int i = 0; i < st_num; i++) { char st[MAX_LEN]; scanf("%s", st); insert(root, st); } } #define MAX_MATRIX_SIZE 101 #define MOD 100000 struct Matrix { int order; int num[MAX_MATRIX_SIZE][MAX_MATRIX_SIZE]; Matrix() {} Matrix(int ord) { order = ord; } void init() { for (int i = 0; i < order; i++) { for (int j = 0; j < order; j++) { num[i][j] = 0; } } } void output() { for (int i = 0; i < order; i++) { for (int j = 0; j < order; j++) { printf("%d ", num[i][j]); } puts(""); } } }; Matrix operator*(Matrix ma, Matrix mb) { int ord = ma.order; Matrix numc(ord); numc.init(); int i, j, k; for (i = 0; i < ord; i++) { for (k = 0; k < ord; k++) { if (ma.num[i][k] == 0) continue; for (j = 0; j < ord; j++) { long long temp = ma.num[i][k] * (long long)mb.num[k][j]; temp %= MOD; numc.num[i][j] += temp; numc.num[i][j] %= MOD; D(printf("%d %d %d ", i, j, numc.num[i][j]);) } } } return numc; } Matrix matrix_power(Matrix ma, int x) { int ord = ma.order; Matrix numc(ord); numc.init(); for (int i = 0; i < ord; i++) { numc.num[i][i] = 1; } for (; x; x >>= 1) { if (x & 1) { numc = numc * ma; } ma = ma * ma; } return numc; } void extract_matrix(Matrix &matrix) { matrix.order = node_cnt; matrix.init(); for (int i = 0; i < node_cnt; i++) { for (int j = 0; j < MAX_CHILD_NUM; j++) { if (trie[i].next[j] == NULL) continue; int temp = trie[i].next[j] - trie; if (trie[temp].count == 0) { matrix.num[i][temp] += 1; //D(printf("%d %d ", i, temp);) } } } } int main() { ac_init(); input(); build_ac_automation(root); Matrix matrix; extract_matrix(matrix); D(matrix.output();) Matrix power = matrix_power(matrix, len); int ans = 0; for (int i = 0; i < node_cnt; i++) ans = (ans + power.num[0][i]) % MOD; D(power.output();) printf("%d ", ans); return 0; }