在 n >= 4 的时候直接暴力模拟（此时最后一个元素至少按 C(t+2, 3) ≈ t^3/6 增长，k <= 1e18 时轮数不超过约 2×10^6）。
在 n <= 3 的时候用二分答案 + 矩阵快速幂去 check。（这里的 n 指去掉前导零之后的数组长度。）
// Codeforces 837F "Prefix Sums".
//
// A_0 is the input array and A_{i+1} is the array of prefix sums of A_i.
// Print the smallest i such that A_i contains an element >= k.
//
// Leading zeros never contribute to any prefix sum, so they are stripped
// first; afterwards the first element is non-zero. Let n be the new length.
//  * n >= 4: the last element grows at least like C(t + 2, 3) ~ t^3 / 6, so
//    only about 2e6 rounds are needed for k <= 1e18 -> simulate directly.
//  * n <= 3: growth may be only linear or quadratic, so binary search on t and
//    evaluate A_t with 3x3 matrix exponentiation using saturating arithmetic.
#include <algorithm>
#include <cstdio>
#include <vector>

using LL = long long;

const LL INF = 0x3f3f3f3f3f3f3f3fLL;
const int MN = 3;

// Saturating product of two non-negative values: x * y, or INF when the
// floating-point estimate of the product exceeds 2e18. The exact 64-bit
// product is only evaluated when it is known to fit.
static LL sat_mul(LL x, LL y) {
    return static_cast<double>(x) * static_cast<double>(y) <= 2e18 ? x * y : INF;
}

// Saturating sum of two values in [0, INF]. 2 * INF still fits in a signed
// 64-bit integer, so the addition itself cannot overflow.
static LL sat_add(LL x, LL y) { return std::min(x + y, INF); }

// MN x MN matrix of non-negative entries under saturating arithmetic
// (every entry is clamped to at most INF).
struct Matrix {
    LL a[MN][MN] = {};

    // Set the diagonal to 1; on a freshly constructed (all-zero) matrix this
    // yields the identity.
    void init() {
        for (int i = 0; i < MN; i++) a[i][i] = 1;
    }

    Matrix operator*(const Matrix &B) const {
        Matrix C;
        for (int i = 0; i < MN; i++)
            for (int j = 0; j < MN; j++)
                for (int t = 0; t < MN; t++)
                    C.a[i][j] = sat_add(C.a[i][j], sat_mul(a[i][t], B.a[t][j]));
        return C;
    }

    // this^e by repeated squaring.
    Matrix operator^(LL e) const {
        Matrix result;
        result.init();
        Matrix base = *this;
        while (e) {
            if (e & 1) result = result * base;
            base = base * base;
            e >>= 1;
        }
        return result;
    }
};

// n >= 4 strategy: apply prefix-sum rounds in place until some element
// reaches k. Requires a[0] > 0 (so the loop terminates) and every element
// < k on entry. Returns the 1-based round number.
static LL simulate(std::vector<LL> &a, LL k) {
    const std::size_t n = a.size();
    for (LL round = 1;; round++) {
        // a[0] never changes and is already < k, so start at index 1.
        for (std::size_t i = 1; i < n; i++) {
            // Both a[i - 1] (already updated) and the old a[i] are < k <= 1e18,
            // so this sum stays below 2e18 and cannot overflow. Checking here,
            // instead of after the whole round, keeps later sums from overflowing.
            a[i] += a[i - 1];
            if (a[i] >= k) return round;
        }
    }
}

// n <= 3 strategy: binary search the smallest t >= 1 for which the last
// element of A_t (the maximum, since prefix sums of non-negative values are
// non-decreasing) is >= k. Caller guarantees a.size() <= 3.
static LL binary_search_rounds(const std::vector<LL> &a, LL k) {
    // One prefix-sum step on a 3-vector (x0, x1, x2) -> (x0, x0+x1, x0+x1+x2).
    static const int STEP[MN][MN] = {{1, 0, 0}, {1, 1, 0}, {1, 1, 1}};
    Matrix step;
    for (int i = 0; i < MN; i++)
        for (int j = 0; j < MN; j++) step.a[i][j] = STEP[i][j];

    // Right-align the input; zero padding in front does not change any sum.
    LL v[MN] = {0, 0, 0};
    const int n = static_cast<int>(a.size());
    for (int i = 0; i < n; i++) v[MN - n + i] = a[i];

    // The last element grows by at least a[0] >= 1 per round, so t <= k <= 1e18.
    LL low = 1, high = static_cast<LL>(1e18), ans = high;
    while (low <= high) {
        const LL mid = low + (high - low) / 2;
        const Matrix p = step ^ mid;
        LL last = 0;
        for (int i = 0; i < MN; i++)
            last = sat_add(last, sat_mul(v[i], p.a[MN - 1][i]));
        if (last >= k) {
            ans = mid;
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }
    return ans;
}

int main() {
    int n;
    LL k;
    if (scanf("%d%lld", &n, &k) != 2) return 0;
    if (n < 0) n = 0;

    std::vector<LL> a(n);
    for (LL &x : a)
        if (scanf("%lld", &x) != 1) return 0;  // elements are 64-bit: %lld, not %d

    // Strip leading zeros: they never contribute to any prefix sum.
    a.erase(a.begin(),
            std::find_if(a.begin(), a.end(), [](LL x) { return x != 0; }));

    const LL mx = a.empty() ? 0 : *std::max_element(a.begin(), a.end());
    if (mx >= k) {
        puts("0");
        return 0;
    }

    if (a.size() >= 4)
        printf("%lld ", simulate(a, k));
    else
        printf("%lld ", binary_search_rounds(a, k));
    return 0;
}