[BZOJ5015][Snoi2017]礼物
试题描述
热情好客的小猴请森林中的朋友们吃饭,他的朋友被编号为 (1∼N),每个到来的朋友都会带给他一些礼物:香蕉。其中,第一个朋友会带给他 (1) 个香蕉,之后,每一个朋友到来以后,都会带给他之前所有人带来的礼物个数再加他的编号的 (K) 次方那么多个。所以,假设 (K=2),前几位朋友带来的礼物个数分别是:
(1,5,15,37,83,…)
假设 (K=3),前几位朋友带来的礼物个数分别是:
(1,9,37,111,…)
现在,小猴好奇自己到底能收到第 (N) 个朋友多少礼物,因此拜托于你了。
已知 (N,K),请输出第 (N) 个朋友送的礼物个数 (mod exttt{ }1000000007)。
输入
第一行,两个整数 (N,K)。
输出
一个整数,表示第 (N) 个朋友送的礼物个数 (mod exttt{ }1000000007)。
输入示例
4 2
输出示例
37
数据规模及约定
(100 exttt{%}) 的数据:(N leq 10^{18}),(K leq 10)。
题解
(N) 那么大,考虑矩阵快速幂。其实矩阵快速幂的题只要想清楚怎么转移一步,代码就可以顺理成章地写出来了。
令 (A_n) 表示第 (n) 个人送礼物的个数,(S_n = sum_{i=1}^n A_i),那么显然 (A_{n+1} = S_n + (n+1)^k),(S_{n+1} = S_n + A_{n+1} = 2 cdot S_n + (n+1)^k)。现在最关键就是由 (n^k) 转移到 ((n+1)^k),由二项式定理我们知道把它展开后由高次到底次排,系数是杨辉三角的第 (k) 行;所以我们在矩阵中分别维护 (1, n, n^2, n^3, cdots , n^k) 就好了,转移矩阵就是一个杨辉三角再加上一些其他细节,读者不妨自己完善。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define LL long long
const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
if(Head == Tail) {
int l = fread(buffer, 1, BufferSize, stdin);
Tail = (Head = buffer) + l;
}
return *Head++;
}
LL read() {
LL x = 0, f = 1; char c = Getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
return x * f;
}
#define maxn 15
#define MOD 1000000007
struct Matrix {
int n, m, A[maxn][maxn];
Matrix() {}
Matrix(int _, int __): n(_), m(__) { memset(A, 0, sizeof(A)); }
Matrix operator * (const Matrix& t) const {
Matrix ans(t.n, m);
for(int i = 1; i <= ans.n; i++)
for(int j = 1; j <= ans.m; j++)
for(int k = 1; k <= n; k++) {
ans.A[i][j] += (LL)t.A[i][k] * A[k][j] % MOD;
if(ans.A[i][j] >= MOD) ans.A[i][j] -= MOD;
}
return ans;
}
Matrix operator *= (const Matrix& t) { *this = *this * t; return *this; }
} base, tr;
Matrix Pow(Matrix a, LL b) {
Matrix ans = a, t = a; b--;
while(b) {
if(b & 1) ans *= t;
t *= t; b >>= 1;
}
return ans;
}
int main() {
LL n = read(); int k = read();
base = Matrix(k + 3, 1); base.A[1][1] = 1;
tr = Matrix(k + 3, k + 3);
for(int i = 1; i <= k + 1; i++)
for(int j = 1; j <= i; j++)
if(i == 1 || j == 1 || j == i) tr.A[i][j] = 1;
else {
tr.A[i][j] = tr.A[i-1][j-1] + tr.A[i-1][j];
if(tr.A[i][j] >= MOD) tr.A[i][j] -= MOD;
}
for(int i = 1; i <= k + 1; i++) tr.A[k+2][i] = tr.A[k+3][i] = tr.A[k+1][i];
tr.A[k+2][k+3] = 1; tr.A[k+3][k+3] = 2;
base *= Pow(tr, n);
printf("%d
", base.A[k+2][1]);
return 0;
}