两矩阵相乘,朴素算法的复杂度是O(N^3)。如果求一次矩阵的M次幂,按朴素的写法就是O(N^3*M)。既然是求幂,不免想到快速幂取模的算法,这里有快速幂取模的介绍,a^b %m 的复杂度可以降到O(logb)。如果矩阵相乘是不是也可以实现O(N^3 * logM)的时间复杂度呢?答案是肯定的。
先定义矩阵数据结构:
struct Mat {
double mat[N][N];
};
O(N^3)实现一次矩阵乘法
Mat operator * (Mat a, Mat b) {
Mat c;
memset(c.mat, 0, sizeof(c.mat));
int i, j, k;
for(k = 0; k < n; ++k) {
for(i = 0; i < n; ++i) {
if(a.mat[i][k] <= 0) continue; //不要小看这里的剪枝,cpu运算乘法的效率并不是想像的那么理想(加法的运算效率高于乘法,比如Strassen矩阵乘法)
for(j = 0; j < n; ++j) {
if(b.mat[k][j] <= 0) continue; //剪枝
c.mat[i][j] += a.mat[i][k] * b.mat[k][j];
}
}
}
return c;
}
下面介绍一种特殊的矩阵:单位矩阵
很明显的可以推知,任何矩阵乘以单位矩阵,其值不改变。
有了前边的介绍,就可以实现矩阵的快速连乘了。
Mat operator ^ (Mat a, int k) {
Mat c;
int i, j;
for(i = 0; i < n; ++i)
for(j = 0; j < n; ++j)
c.mat[i][j] = (i == j); //初始化为单位矩阵
for(; k; k >>= 1) {
if(k&1) c = c*a;
a = a*a;
}
return c;
}
举个例子:
求第n个Fibonacci数模M的值。如果这个n非常大的话,普通的递推时间复杂度为O(n),这样的复杂度很有可能会挂掉。这里可以用矩阵做优化,复杂度可以降到O(logn * 2^3)
如图:
A = F(n - 1), B = F(N - 2),这样使构造矩阵的n次幂乘以初始矩阵得到的结果就是。
因为是2*2的据称,所以一次相乘的时间复杂度是O(2^3),总的复杂度是O(logn * 2^3 + 2*2*1)。
下面给出一种比较基础的类型的矩阵快速幂:
f(n)= a*f(n-1)+b*f(n-2)型
下面两题适合作为此种矩阵快速幂的模板来使用
http://poj.org/problem?id=3070 (纯模板题,直接用)
#include<iostream> #include<cstdio> #include<cstring> using namespace std; struct Mat { int mat[2][2]; }; Mat d; int n,mod; Mat mul(Mat a,Mat b) { Mat c; memset(c.mat,0,sizeof(c.mat)); for(int i=0;i<n;++i) { for(int k=0;k<n;++k) { if(a.mat[i][k]) for(int j=0;j<n;++j) { c.mat[i][j]+=a.mat[i][k]*b.mat[k][j]; if(c.mat[i][j]>=mod) c.mat[i][j]%=mod; } } } return c; } Mat expo(Mat p,int k) { if(k==1) return p; Mat e; memset(e.mat,0,sizeof(e.mat)); for(int i=0;i<n;++i) e.mat[i][i]=1; if(k==0) return e; while(k) { if(k&1) e=mul(p,e); p=mul(p,p); k>>=1; } return e; } int main() { n=2; mod=10000; d.mat[1][1]=0; d.mat[0][1]=d.mat[1][0]=d.mat[0][0]=1; int k; while(cin>>k) { if(k==-1) break; Mat res=expo(d,k); int ans=res.mat[0][1]%mod; cout<<ans<<endl; } return 0; }
链接:http://codeforces.com/contest/450/problem/B (纯模板题的变形题)
#include<iostream> #include<cstdio> #include<cstring> using namespace std; struct Mat { int mat[2][2]; }; Mat d; int n,mod; Mat mul(Mat a,Mat b) { Mat c; memset(c.mat,0,sizeof(c.mat)); for(int i=0;i<n;++i) { for(int k=0;k<n;++k) { if(a.mat[i][k]) for(int j=0;j<n;++j) { c.mat[i][j]+=a.mat[i][k]*b.mat[k][j]; if(c.mat[i][j]>=mod) c.mat[i][j]%=mod; } } } return c; } Mat expo(Mat p,int k) { if(k==1) return p; Mat e; memset(e.mat,0,sizeof(e.mat)); for(int i=0;i<n;++i) e.mat[i][i]=1; if(k==0) return e; while(k) { if(k&1) e=mul(p,e); p=mul(p,p); k>>=1; } return e; } int main() { n=2; mod=10000; d.mat[1][1]=0; d.mat[0][1]=d.mat[1][0]=d.mat[0][0]=1; int k; while(cin>>k) { if(k==-1) break; Mat res=expo(d,k); int ans=res.mat[0][1]%mod; cout<<ans<<endl; } return 0; }
S = A + A^2 + A^3 + … + A^k类型
链接:http://poj.org/problem?id=3233
给定三个参数n、k、m,n为矩阵的行数和列数,k表示最高次幂,m用于取模。
对于给定的矩阵A,要求输出A^1+A^2+……+A^k的结果矩阵。
求A^i可以使用二分快速幂,这个是足够快的了。
但k最大可以达到10^9,因此虽然题目只有一组数据,但直接一次循环也必然超时。
这里的求和可以采用二分的思想:
对于S=A^1+A^2+……+A^k
若k是偶数,则S=(1+A^(k/2))(A^1+A^2+……+A^(k/2))
若k是奇数,则S=(1+A^(k/2))(A^1+A^2+……+A^(k/2))+A^k
以上的k/2指的是程序中的除法,即舍弃小数的除法。
采用这种二分思想,可以大大减少时间复杂度,因此可以满足题目的要求。
应当注意的是这里要求的结果矩阵是每个元素模m之后的矩阵,可以在运算过程中可能超过m的时候判断一下,对m取模。
#include<iostream> #include<cstdio> #include<cstring> using namespace std; const int maxn=31; struct Mat { int mat[maxn][maxn]; }; Mat d; int n,m; Mat mul(Mat a,Mat b) { Mat c; memset(c.mat,0,sizeof(c.mat)); for(int i=0;i<n;++i) { for(int k=0;k<n;++k) { if(a.mat[i][k]) for(int j=0;j<n;++j) { c.mat[i][j]+=a.mat[i][k]*b.mat[k][j]; if(c.mat[i][j]>=m) c.mat[i][j]%=m; } } } return c; } Mat expo(Mat p,int k) { if(k==1) return p; Mat e; memset(e.mat,0,sizeof(e.mat)); for(int i=0;i<n;++i) e.mat[i][i]=1; while(k) { if(k&1) e=mul(p,e); p=mul(p,p); k>>=1; } return e; } Mat sum(Mat p,int k) { for(int i=0;i<n;++i) { for(int j=0;j<n;++j) { if(p.mat[i][j]>=m) p.mat[i][j]%=m; } } if(k==1) return p; Mat m1=expo(p,k/2); for(int i=0;i<n;++i) m1.mat[i][i]+=1; Mat m2=sum(p,k/2); Mat m3=mul(m1,m2); if(k&1) { Mat temp=expo(p,k); for(int i=0;i<n;++i) { for(int j=0;j<n;++j) { m3.mat[i][j]+=temp.mat[i][j]; if(m3.mat[i][j]>=m) m3.mat[i][j]%=m; } } } return m3; } int main() { int k; while(cin>>n>>k>>m) { for(int i=0;i<n;++i) for(int j=0;j<n;++j) scanf("%d",&d.mat[i][j]); Mat arry=sum(d,k); for(int i=0;i<n;++i) { for(int j=0;j<n;++j) { if(j!=0) printf(" "); printf("%d",arry.mat[i][j]); } printf(" "); } } return 0; }
矩阵变换类型
这种题目可以用矩阵快速幂,从而实现矩阵的多次变换
链接:http://poj.org/problem?id=3735
【题意】:有n只猫咪,开始时每只猫咪有花生0颗,现有一组操作,由下面三个中的k个操作组成:
1. g i 给i只猫咪一颗花生米
2. e i 让第i只猫咪吃掉它拥有的所有花生米
3. s i j 将猫咪i与猫咪j的拥有的花生米交换
现将上述一组操作做m次后,问每只猫咪有多少颗花生?
【题解】:m达到10^9,显然不能直接算。
因为k个操作给出之后就是固定的,所以想到用矩阵,矩阵快速幂可以把时间复杂度降到O(logm)。问题转化为如何构造转置矩阵?
说下我的思路,观察以上三种操作,发现第二,三种操作比较容易处理,重点落在第一种操作上。
有一个很好的办法就是添加一个辅助,使初始矩阵变为一个n+1元组,编号为0到n,下面以3个猫为例:
定义初始矩阵A = [1 0 0 0],0号元素固定为1,1~n分别为对应的猫所拥有的花生数。
对于第一种操作g i,我们在单位矩阵基础上使Mat[0][i]变为1,例如g 1:
1 1 0 0
0 1 0 0
0 0 1 0
0 0 0 1,显然[1 0 0 0]*Mat = [1 1 0 0]
对于第二种操作e i,我们在单位矩阵基础使Mat[i][i] = 0,例如e 2:
1 0 0 0
0 1 0 0
0 0 0 0
0 0 0 1, 显然[1 2 3 4]*Mat = [1 2 0 4]
对于第三种操作s i j,我们在单位矩阵基础上使第i列与第j互换,例如s 1 2:
1 0 0 0
0 0 0 1
0 0 1 0
0 1 0 0,显然[1 2 0 4]*Mat = [1 4 0 2]
现在,对于每一个操作我们都可以得到一个转置矩阵,把k个操作的矩阵相乘我们可以得到一个新的转置矩阵T。
A * T 表示我们经过一组操作,类似我们可以得到经过m组操作的矩阵为 A * T ^ m,最终矩阵的[0][1~n]即为答案。
上述的做法比较直观,但是实现过于麻烦,因为要构造k个不同矩阵。
有没有别的方法可以直接构造转置矩阵T?答案是肯定的。
我们还是以单位矩阵为基础:
对于第一种操作g i,我们使Mat[0][i] = Mat[0][i] + 1;
对于第二种操作e i,我们使矩阵的第i列清零;
对于第三种操作s i j,我们使第i列与第j列互换。
这样实现的话,我们始终在处理一个矩阵,免去构造k个矩阵的麻烦。
至此,构造转置矩阵T就完成了,接下来只需用矩阵快速幂求出 A * T ^ m即可,还有一个注意的地方,该题需要用到long long。
具体实现可以看下面的代码。
个人采用的是第二种方法
#include <iostream> #include <cstring> #include <cstdio> #define LL long long using namespace std; struct met{ LL at[105][105]; }; met ret,d; LL n,m,k; met mul(met a,met b) { memset(ret.at,0,sizeof(ret.at)); for(int i=0;i<=n;++i) { for(int k=0;k<=n;++k) { if(a.at[i][k]) { for(int j=0;j<=n;++j) { ret.at[i][j]+=a.at[i][k]*b.at[k][j]; } } } } return ret; } met expo(met a,LL k) { if(k==1) return a; met e; memset(e.at,0,sizeof(e.at)); for(int i=0;i<=n;++i){e.at[i][i]=1;} if(k==0)return e; while(k) { if(k&1)e=mul(e,a); k>>=1; a=mul(a,a); } return e; } int main() { while(~scanf("%lld%lld%lld",&n,&m,&k)) { LL a,b; char ch[5]; if(!n&&!k&&!m)break; memset(d.at,0,sizeof(d.at)); for(int i=0;i<=n;++i) {d.at[i][i]=1;} while(k--) { scanf("%s",ch); if(ch[0]=='g') { scanf("%lld",&a); d.at[0][a]++; } else if(ch[0]=='e') { scanf("%lld",&a); for(int i=0;i<=n;++i) { d.at[i][a]=0; } } else { scanf("%lld%lld",&a,&b); for(int i=0;i<=n;++i) { LL t=d.at[i][a]; d.at[i][a]=d.at[i][b]; d.at[i][b]=t; } } } met ans=expo(d,m); printf("%lld",ans.at[0][1]); for(int i=2;i<=n;++i) { printf(" %lld",ans.at[0][i]); } printf(" "); } return 0; }