题目大意:有一种长度为$n(nleqslant 10^{18})$的字符串,给定$m(mleqslant10^3)$种限制,即字符$c$出现的次数为$cnt$,若一个字符有多种限制,则满足任意一个即可,求这种字符串有多少个,所有的$cnt$相乘小于等于 123,答案对 12345 取模。
题解:最多$6$个限制的$cnt ot=2$,状态只需要记录这些不为$1$的限制,可以把每个限制出现次数压成一个数,构建矩阵,快速幂即可
卡点:无
C++ Code:
#include <cstdio> #include <vector> #define maxn 1010 const int mod = 12345; inline int min(int a, int b) {return a < b ? a : b;} inline int max(int a, int b) {return a > b ? a : b;} inline void up(int &a, int b) {if ((a += b) >= mod) a -= mod;} inline long long pw(long long base, long long p) { base %= mod; long long res = 1; for (; p; p >>= 1, base = base * base % mod) if (p & 1) res = res * base % mod; return res; } int __sz = 1; struct matrix { #define __M 150 #define M __sz int s[__M][__M]; inline matrix() { __builtin_memset(s, 0, sizeof s); } inline friend matrix operator * (const matrix &lhs, const matrix &rhs) { matrix res; for (register int i = 0; i < M; i++) { for (register int j = 0; j < M; j++) { long long tmp = 0; for (register int k = 0; k < M; k++) tmp += static_cast<long long> (lhs.s[i][k]) * rhs.s[k][j]; res.s[i][j] = tmp % mod; } } return res; } #undef __M #undef M } BASE, RES; long long n; int m; int mp[300], ret[300], __name, prod[300]; int base[300]; std::vector<int> v[300]; inline int get(int x, int i) {return x / base[i - 1] % prod[i];} inline bool check(int x) { for (int i = 1; i <= __name; i++) { bool find = false; int now = get(x, i); for (std::vector<int>::iterator it = v[i].begin(); it != v[i].end(); it++) if (now % *it == 0) { find = true; break; } if (!find) return false; } return true; } int main() { scanf("%lld%d", &n, &m); for (int i = 1, x, ch; i <= m; i++) { char __ch; scanf("%1s%d", &__ch, &x); ch = static_cast<int>(__ch); if (!mp[ch]) mp[ch] = ++__name, ret[__name] = ch, prod[__name] = 1; prod[mp[ch]] *= x; v[mp[ch]].push_back(x); __sz *= x; } base[0] = 1; for (int i = 1; i <= __name; i++) base[i] = base[i - 1] * prod[i]; for (int i = 0; i < __sz; i++) { for (int j = 1; j <= __name; j++) { int now = get(i, j), nxt = (now + 1) % prod[j]; up(BASE.s[i][i + (nxt - now) * base[j - 1]], 1); } } RES.s[0][0] = 1; for (; n; n >>= 1, BASE = BASE * BASE) if (n & 1) RES = RES * BASE; int ans = 0; for (int i = 0; i < __sz; i++) if (check(i)) up(ans, RES.s[0][i]); printf("%d ", ans); return 0; }