http://acm.hdu.edu.cn/showproblem.php?pid=5667
这题的关键是处理指数,因为最后结果是a^t这种的,主要是如何计算t。
发现t是一个递推式,t(n) = c*t(n-1)+t(n-2)+b。这样的话就可以使用矩阵快速幂进行计算了。
设列矩阵[t(n), t(n-1), 1],它可以由[t(n-1), t(n-2), 1]乘上一个3*3的矩阵得到这个矩阵为:{[c, 1, b], [1, 0, 0], [0, 0, 1]},这样指数部分就可以矩阵快速幂了。
但是如果指数不模的话,计算肯定爆了,这里需要考虑费马小定理,a^(p-1) = 1(mod p),于是指数就可以模(p-1)了。
最后算出指数后,再来一次快速幂即可。
但是打这场BC的时候,我并没有考虑到a%p = 0的情况。。。最终错失这题,只过了三题。
代码:
#include <iostream> #include <cstdio> #include <cstdlib> #include <cmath> #include <cstring> #include <algorithm> #include <set> #include <map> #include <queue> #include <vector> #include <string> #define LL long long using namespace std; //矩阵乘法 //方阵 #define maxN 4 struct Mat { LL val[maxN][maxN], p; int len; Mat() { len = 3; } Mat operator=(const Mat& a) { len = a.len; p = a.p; for (int i = 0; i < len; ++i) for (int j = 0; j < len; ++j) val[i][j] = a.val[i][j]; return *this; } Mat operator*(const Mat& a) { Mat x; x.p = a.p; memset(x.val, 0, sizeof(x.val)); for (int i = 0; i < len; ++i) for (int j = 0; j < len; ++j) for (int k = 0; k < len; ++k) if (val[i][k] && a.val[k][j]) x.val[i][j] = (x.val[i][j] + val[i][k]*a.val[k][j]%p)%p; return x; } Mat operator^(const LL& a) { LL n = a; Mat x, p = *this; memset(x.val, 0, sizeof(x.val)); x.p = this->p; for (int i = 0; i < len; ++i) x.val[i][i] = 1; while (n) { if (n & 1) x = x * p; p = p * p; n >>= 1; } return x; } }from, mat; LL n, a, b, c, p; //快速幂m^n LL quickPow(LL x, LL n) { LL a = 1; while (n) { a *= n&1 ? x : 1; a %= p; n >>= 1 ; x *= x; x %= p; } return a; } void work() { if (a%p == 0) { if (n == 1) printf("1 "); else printf("0 "); return; } LL t, ans; if (n == 1) t = 0; else if (n == 2) t = b%(p-1); else { memset(from.val, 0, sizeof(from.val)); from.val[0][0] = c; from.val[0][1] = 1; from.val[0][2] = b; from.val[1][0] = 1; from.val[2][2] = 1; from.len = 3; from.p = p-1; mat = from^(n-2); t = (mat.val[0][0]*b%(p-1)+mat.val[0][2])%(p-1); } ans = quickPow(a, t); cout << ans << endl; } int main() { //freopen("test.in", "r", stdin); int T; scanf("%d", &T); for (int times = 1; times <= T; ++times) { cin >> n >> a >> b >> c >> p; work(); } return 0; }