以下摘自《信息学奥赛一本通提高篇》
矩阵乘法
设A,B是两个矩阵,令C=A*B;
1.A的列数必须与B的行数相等
2.设A是n*r的矩阵,B是r*m的矩阵,那么A与B的乘积C是一个n*m的矩阵
3.C[i][j]=A[i][k]*B[k][j] (k=1~r)
方阵乘幂
我们用快速幂的思想来求方阵乘幂。
矩阵乘法的应用
1.很容易将有用的状态储存于一个矩阵中
2.通过状态矩阵与状态转移矩阵相乘可快速得到一次DP的值(注意这个DP的状态转移方程必须要是一次的递推式)
3.求矩阵相乘的结果是要做很多次的乘法,这样的效率非常低,但由于矩阵乘法满足结合律,可以先算后面的转移矩阵,即用快速幂,迅速处理好后面的转移矩阵,再用初始矩阵乘上后面的转移矩阵得到结果,时间复杂度为log(n)级别的
做题!
1 #include<iostream> 2 #include<cstdio> 3 #include<cmath> 4 #include<algorithm> 5 #include<cstring> 6 #define R register 7 #define go(i,a,b) for(R int i=a;i<=b;i++) 8 #define yes(i,a,b) for(R int i=a;i>=b;i--) 9 #define ll long long 10 #define db double 11 using namespace std; 12 ll n,m;int mod; 13 struct node{ll mt[3][3];}; 14 node calc(node x,node y,int a,int b,int c) 15 { 16 node z;memset(z.mt,0,sizeof(z.mt)); 17 go(i,1,a) 18 go(j,1,b) 19 go(k,1,c) 20 z.mt[i][j]=(z.mt[i][j]+x.mt[i][k]*y.mt[k][j])%mod; 21 return z; 22 } 23 node a,b,c; 24 void sol() 25 { 26 while(m) 27 { 28 if(m&1) b=calc(b,a,2,2,2); 29 m>>=1;a=calc(a,a,2,2,2); 30 } 31 } 32 int main() 33 { 34 scanf("%lld",&n);m=n-2;mod=1000000007; 35 if(n<=2){printf("1");return 0;} 36 a.mt[1][1]=1,a.mt[1][2]=1,a.mt[2][1]=1,a.mt[2][2]=0; 37 b.mt[1][1]=1,b.mt[1][2]=0,b.mt[2][1]=0,b.mt[2][2]=1; 38 c.mt[1][1]=1;c.mt[2][1]=1; 39 sol(); 40 c=calc(b,c,2,1,2); 41 printf("%lld",c.mt[1][1]); 42 return 0; 43 }
1 #include<iostream> 2 #include<cstdio> 3 #include<cmath> 4 #include<algorithm> 5 #include<cstring> 6 #define R register 7 #define go(i,a,b) for(R int i=a;i<=b;i++) 8 #define yes(i,a,b) for(R int i=a;i>=b;i--) 9 #define ll long long 10 #define db double 11 using namespace std; 12 ll n,m;int mod; 13 struct node{ll mt[4][4];}; 14 node calc(node x,node y,int a,int b,int c) 15 { 16 node z;memset(z.mt,0,sizeof(z.mt)); 17 go(i,1,a) 18 go(j,1,b) 19 go(k,1,c) 20 z.mt[i][j]=(z.mt[i][j]+x.mt[i][k]*y.mt[k][j])%mod; 21 return z; 22 } 23 node a,b,c; 24 void sol() 25 { 26 while(m) 27 { 28 if(m&1) b=calc(b,a,3,3,3); 29 m>>=1;a=calc(a,a,3,3,3); 30 } 31 } 32 int main() 33 { 34 scanf("%lld%d",&n,&mod);m=n-2; 35 if(n==1){printf("1");return 0;} 36 if(n==2){printf("2");return 0;} 37 a.mt[1][1]=1,a.mt[1][2]=1,a.mt[1][3]=1; 38 a.mt[2][1]=0,a.mt[2][2]=1,a.mt[2][3]=1; 39 a.mt[3][1]=0,a.mt[3][2]=1,a.mt[3][3]=0; 40 b.mt[1][1]=1,b.mt[1][2]=0,b.mt[1][3]=0; 41 b.mt[2][1]=0,b.mt[2][2]=1,b.mt[2][3]=0; 42 a.mt[3][1]=0,b.mt[3][2]=0,b.mt[3][3]=1; 43 c.mt[1][1]=2,c.mt[2][1]=1,c.mt[3][1]=1; 44 sol(); 45 c=calc(b,c,3,1,3); 46 printf("%lld",c.mt[1][1]); 47 return 0; 48 }