Description
给出一张n个点的有向图G(V,E)。对于任意两个点u,v(u可以等于v),u向v的连边数为:
∑OUT(u,i) * IN(v,i),其中1<=i<=K
其中k和数组out,in均已知,现在给出m个询问,每次询问给出三个参数u,v,d,你需要回答从节点
u出发,经过不超过d条边到达节点v的路径有多少种。答案模10^9+7。
Input
第一行两个整数n,k。
接下来n行,第i行有2k个整数,前k个整数描述outi,后k个数描述ini。
接下来一行一个整数m。
接下来m行,每行三个整数u,v,d(0<=d<2^31),描述一组询问。
Output
对于每个询问,输出一行一个整数,描述答案。
Sample Input
5 2
2 5 4 3
7 9 2 4
0 1 5 2
6 3 9 2
2147483647 1000000001 233522 788488
10
1 1 0
2 2 1
2 4 5
4 3 10
3 4 50
1 5 1000
2 5 4 3
7 9 2 4
0 1 5 2
6 3 9 2
2147483647 1000000001 233522 788488
10
1 1 0
2 2 1
2 4 5
4 3 10
3 4 50
1 5 1000
Sample Output
1
51
170107227
271772358
34562176
890241289
51
170107227
271772358
34562176
890241289
HINT
1<=N<=1000
1<=K<=20
1<=M<=50
题解:
将in转置一下
设G[u][v]=u到v直接的路径条数,显然g=out*in
然后答案矩阵显然是Gd
但G是1000*1000的矩阵,直接写一定会T
注意到Gd=(out*in)d=out*(in*out)d-1*in
而in*out是20*20的矩阵
然后求答案时并不用把Gd求出来,我们只需要u到v的路径数,不然复杂度还是过不去
注意常数(卡蠢我了)
code:
1 #include<cstdio> 2 #include<iostream> 3 #include<cmath> 4 #include<cassert> 5 #include<cstring> 6 #include<algorithm> 7 #define maxn 21 8 #define mod 1000000007 9 using namespace std; 10 char ch; 11 bool ok; 12 void read(int &x){ 13 for (ok=0,ch=getchar();!isdigit(ch);ch=getchar()) if (ch=='-') ok=1; 14 for (x=0;isdigit(ch);x=x*10+ch-'0',ch=getchar()); 15 if (ok) x=-x; 16 } 17 int n,m,q,a,b,d; 18 int out[1005][maxn],in[maxn][1005],list[maxn],ans; 19 struct Matrix{ 20 int v[maxn][maxn]; 21 void init(int op){ 22 for (int i=1;i<=m;i++) for (int j=1;j<=m;j++) v[i][j]=(i==j)*op; 23 } 24 }base,I,emp,tmp,res,sum,last; 25 Matrix operator+(const Matrix &a,const Matrix &b){ 26 static Matrix c; 27 c=emp; 28 for (int i=1;i<=m;i++) for (int j=1;j<=m;j++) c.v[i][j]=(a.v[i][j]+b.v[i][j])%mod; 29 return c; 30 } 31 Matrix operator*(const Matrix &a,const Matrix &b){ 32 static Matrix c; 33 c=emp; 34 for (int i=1;i<=m;i++) for (int j=1;j<=m;j++) for (int k=1;k<=m;k++) 35 c.v[i][k]=(c.v[i][k]+1LL*a.v[i][j]*b.v[j][k])%mod; 36 return c; 37 } 38 void solve(int n){//1+a+a^2+a^3+...+a^n 39 if (n<0){res=emp;return;} 40 Matrix a=base,b=a,v=I,s=I; 41 for(int i=n;i;i>>=1,b=b*(a+I),a=a*a) 42 if(i&1) s=s+b*v,v=v*a; 43 res=s; 44 } 45 int main(){ 46 read(n),read(m),I.init(1),emp.init(0); 47 for (int i=1;i<=n;i++){ 48 for (int j=1;j<=m;j++) read(out[i][j]); 49 for (int j=1;j<=m;j++) read(in[j][i]); 50 } 51 for (int i=1;i<=m;i++) for (int j=1;j<=m;j++) base.v[i][j]=0; 52 for (int i=1;i<=m;i++) for (int j=1;j<=n;j++) for (int k=1;k<=m;k++) 53 base.v[i][k]=(base.v[i][k]+1LL*in[i][j]*out[j][k])%mod; 54 read(q); 55 while (q--){ 56 read(a),read(b),read(d); 57 solve(d-1); 58 for (int i=1;i<=m;i++) list[i]=0; 59 for (int i=1;i<=m;i++) for (int j=1;j<=m;j++) list[j]=(list[j]+1LL*out[a][i]*res.v[i][j])%mod; 60 ans=0; 61 for (int i=1;i<=m;i++) ans=(ans+1LL*list[i]*in[i][b])%mod; 62 printf("%d ",ans+(a==b)); 63 } 64 return 0; 65 }