题解:
一看$n$就知道是矩乘加速递推。
问题是怎么推。
不妨认为生成树的边是大号指向小号的。
首先,对于一般节点$x$,$x$可以连到$x-k$,但是连不到$x-k-1$。(废话)
所以我们处理点$x$时要确保$x-k$已经在前面的生成树里面了。
然后就是状态的问题。
由于$x$只能连到$x-k$到$x-1$,我们可以只处理$k$个数的分组状态。
$k$最大为$5$,$5$个数有多少状态?
数一下。
{5}:*1;
{4,1}:*5;
{3,2}:*10;
{3,1,1}:*10;
{2,2,1}:*15;
{2,1,1,1}:*10;
{1,1,1,1,1}:*1;
一共$52$种。
那$52*52$的矩阵就很好了。
先搜出所有状态,然后预处理转移矩阵,最后矩乘就好了。
代码:
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define ll long long #define N 60 #define MOD 65521 ll n; int k,cnt; struct tgb { int a[7]; }g[N]; int to[47000];//hash 6 int t[7],las[7]; struct Matrix { ll s[N][N]; Matrix(){memset(s,0,sizeof(s));} }j0,j1; Matrix operator * (Matrix &a,Matrix &b) { Matrix c; for(int i=1;i<=cnt;i++) for(int j=1;j<=cnt;j++) for(int k=1;k<=cnt;k++) (c.s[i][j] += a.s[i][k]*b.s[k][j])%=MOD; return c; } Matrix operator ^ (Matrix &a,ll b) { Matrix c; c = a,b--; while(b) { if(b&1)c=c*a; a=a*a; b>>=1; } return c; } ll fastpow(ll x,int y) { ll ret = 1; while(y>0) { if(y&1)ret = ret*x%MOD; x = x*x%MOD; y>>=1; } return ret%MOD; } int vis[7],tim,jc[7]; void dfs(int dep,int mx) { if(dep==k+1) { cnt++;int hs = 0; for(int i=1;i<=k;i++)g[cnt].a[i]=t[i],hs=hs*6+t[i]; to[hs] = cnt; memset(vis,0,sizeof(vis)); for(int i=1;i<=k;i++)vis[t[i]]++; j0.s[cnt][1]=1; for(int i=1;vis[i];i++)j0.s[cnt][1]*=jc[vis[i]]; return ; } for(int i=1;i<=mx;i++) { t[dep] = i; dfs(dep+1,mx); } t[dep] = mx+1; dfs(dep+1,mx+1); } int get_nam() { int hs = 0; for(int i=1;i<=k;i++)hs=hs*6+t[i]; return to[hs]; } void init() { for(int i=1;i<=cnt;i++) { for(int j=0;j<=k-1;j++)t[j] = las[j] = g[i].a[j+1]; t[k]=las[k]=6; for(int j=0;j<(1<<k);j++) { bool ot = 0; for(int o=0;o<=k-1;o++) if(j&(1<<o)) { int tg = t[o]; if(t[o]==6) { ot = 1; break; } for(int u=0;u<=k-1;u++) if(t[u]==tg)t[u]=6; } if(ot) { for(int o=0;o<=k;o++)t[o]=las[o]; continue; } memset(vis,0,sizeof(vis)); tim = 0; for(int o=1;o<=k;o++) { if(!vis[t[o]])vis[t[o]]=++tim; t[o] = vis[t[o]]; } if(vis[t[0]])j1.s[get_nam()][i]++; for(int o=0;o<=k;o++)t[o]=las[o]; } } } int main() { scanf("%d%lld",&k,&n); if(n<=k) { ll ans = fastpow(n,n-2); printf("%lld ",ans); return 0; } for(int i=1;i<=k;i++)jc[i]=fastpow(i,i-2); dfs(1,0); init(); Matrix ans = j1^(n-k); ans = ans*j0; printf("%lld ",ans.s[1][1]); return 0; }