这题什么素质,爆long long就算了,连int128都爆……最后还是用long double卡过的……而且可能是我本身自带大常数吧,T了好长时间……
先说一下超级汇点的计数吧,先说结论:
1.将所有点(此题中只有一级点)向一个超级汇点0连边,将矩阵乘n次,相应的f[i][j]即为从i到j的走n步方案数,f[i][0]为i到0走n步的方案数,若在给他乘一个ans矩阵(ans在前),则f[0][0]-n(点数)为所有长度等于n(指数)的路径的方案数。ans矩阵为0向所有其他点连边。
2.若在1中,将0想自己连边,则每次相乘都会积累,最终得出的即为所有长度小于等于n的路径方案数。
具体可以这样理解:
ans矩阵相当于从超级汇点出发走一步,每乘一个base矩阵,相当于走一步,乘了n次后,相当于走n步,但是还要再乘一个base,相当于各点回到0,而计数器中仍保留着从0走出的方案数,此时f[0][0]-点数即为答案。(这种问题自己手模一下会更容易理解吧)。
然后是题解:
边权只有1,2,3三种,考虑拆点(在‘迷路’中也用到了同样的方法),将一个点分为三级,$g
et
(int po,int w){return (po-1)*3+w;}
$,将每个点的第一级向第二级连边,第二级向第三级连边,对于一条a->b,长度为w的边,从a的第w级向b的第一级连长度为1的边。
代码实现:
1 for(int i=1;i<=m;i++) 2 { 3 a=read(),b=read(),c=read(); 4 cs.m[get(a,c)][get(b,1)]++; 5 } 6 for(int i=1;i<=n;i++) 7 { 8 cs.m[get(i,1)][get(i,2)]++; 9 cs.m[get(i,2)][get(i,3)]++; 10 }
这样就得到了一个初始矩阵,构造出ans矩阵,显然可以二分枚举长度解决,但是复杂度比较高会T,考虑倍增,提前预处理出初始矩阵乘$2^i$后的矩阵,像LCA那样搞就可以了。
然而这道题还有几个坑点:
方案数乘的时候会爆longlong(如果你打的恶心点连__int128都会爆),可以加判断,个人感觉比较麻烦,于是就用了double,还会爆?丝毫不慌还有long double。
然后就T了,用lemon测了一下,跑了一百多秒,好在都跑对了,其实这不是long double的锅,和我自带的大常数关系也不大,在预处理倍增数组时我固定给他求到了65,导致时间比较长,其实可以记录一下:
1 for(int i=1;i<=65;i++,imax++){F[i]=F[i-1]*F[i-1];if((ans*F[i]).count()>k)break;}
然后就A了,跑得还挺快。其实我还是搞不懂为啥会差这么多,固定求到65复杂度也是$n^3log_n$啊……
#include<iostream> #include<cstring> #include<cstdio> #include<cmath> #define N n*3 #define LD long double #define LL long long using namespace std; int n,m;LL k; struct jz { LD m[121][121]; LD count() {return m[0][0]-n;} }cs,ans,F[70]; jz operator * (jz &a,jz &b) { jz ans; for(int i=0;i<=N;i++) for(int j=0;j<=N;j++) ans.m[i][j]=0; for(int i=0;i<=N;i++) for(int j=0;j<=N;j++) for(int k=0;k<=N;k++) ans.m[i][j]+=a.m[i][k]*b.m[k][j]; return ans; } inline int get(const register int po,const register int w){return (po-1)*3+w;} inline LL read() { LL s=0;char a=getchar(); while(a<'0'||a>'9')a=getchar(); while(a>='0'&&a<='9'){s=s*10+a-'0';a=getchar();} return s; } signed main() { // freopen("10.in","r",stdin); n=read(),m=read(),k=read(); int a,b,c; for(int i=1;i<=m;i++) { a=read(),b=read(),c=read(); cs.m[get(a,c)][get(b,1)]++; } for(int i=1;i<=n;i++) { cs.m[get(i,1)][get(i,2)]++; cs.m[get(i,2)][get(i,3)]++; } LL imax=1; cs.m[0][0]=1; for(int i=1;i<=n;i++)cs.m[get(i,1)][0]++; for(int i=1;i<=n;i++)ans.m[0][get(i,1)]=1; F[0]=cs; for(int i=1;i<=65;i++,imax++){F[i]=F[i-1]*F[i-1];if((ans*F[i]).count()>k)break;} if((ans*F[imax]).count()<k){cout<<-1<<endl;return 0;} LL num=0; for(int i=imax;i>=0;i--) { jz tm=ans*F[i]; if(tm.count()<k){num+=1ll<<i;ans=ans*F[i];} } cout<<num<<endl; }