题目分析:
用数论分块的思想,就会发现其实就是连续一段的长度$i$的高度不能超过$lfloor frac{k}{i}
floor$,然后我们会发现最长的非$0$一段不会超过$k$,所以我们可以弄一个长度为$i$的非$0$段的个数称为"元",然后用
"元"去递推。
这个"元"的求法用DP:令数论分块之后第$i$段的长度为$g[i]$
$$f[i][j] = f[i-1][j] + f[i-1][k]*f[i][j-k-1]*g[i]$$
$$f[1][1] = 1$$
然后接下来就是一个很简单的矩阵快速幂了。
对于$k$很大的情况用多项式取模优化就行了。
代码:
1 #include<bits/stdc++.h> 2 using namespace std; 3 4 const int mod = 998244353; 5 6 int n,k,x,y; 7 int maxx[1060],pos[1060],cas[150],num; 8 int g[120][1060],f[1060]; 9 10 int FF[1060],ans[1060],dt[1060],zeta[2060]; 11 12 int fast_pow(int now,int pw){ 13 int ans = 1,dt = now,bit = 1; 14 while(bit <= pw){ 15 if(bit & pw) ans = 1ll*ans*dt%mod; 16 bit<<=1; dt = 1ll*dt*dt%mod; 17 } 18 return ans; 19 } 20 21 void multi(int *A,int *B,int len){ 22 memset(zeta,0,sizeof(zeta)); 23 for(int i=0;i<len;i++){ 24 for(int j=0;j<len;j++) zeta[i+j]+=1ll*A[i]*B[j]%mod,zeta[i+j]%=mod; 25 } 26 for(int i=len*2;i>=len;i--){ 27 if(zeta[i] == 0) continue; 28 int pp = zeta[i],dr = (FF[len]==1?1:mod-1); 29 for(int j=i,kk=len;kk>=0;kk--,j--){ 30 zeta[j] -= 1ll*pp*FF[kk]%mod*dr%mod; 31 if(zeta[j] <0) zeta[j] += mod; 32 } 33 } 34 for(int i=0;i<len;i++) A[i] = zeta[i]; 35 } 36 37 void fastpow(int len,int pw){ 38 memset(ans,0,sizeof(ans)); ans[0] = 1; 39 memset(dt,0,sizeof(dt)); dt[1] = 1; 40 int bit = 1; 41 while(bit<=pw){ 42 if(bit & pw) multi(ans,dt,len); 43 multi(dt,dt,len); bit<<=1; 44 } 45 } 46 47 void getbase(int now){ 48 memset(g,0,sizeof(g)); 49 g[0][0] = 1; 50 for(int i=1;i<=num;i++){ 51 g[i][0] = 1; 52 for(int j=1;j<=now/cas[i];j++){ 53 g[i][j] += g[i-1][j]; g[i][j] %= mod; 54 for(int l=0;l<j;l++){ 55 for(int ec = cas[i+1]+1;ec<=cas[i];ec++){ 56 g[i][j]+=1ll*g[i-1][l]%mod*g[i][j-l-1]%mod*pos[ec]%mod; 57 g[i][j] %= mod; 58 } 59 } 60 } 61 } 62 } 63 64 int work(int now){ 65 memset(cas,0,sizeof(cas)); memset(maxx,0,sizeof(maxx)); 66 memset(f,0,sizeof(f)); 67 num = 0; 68 for(int i=1;i<=now;i++) { 69 maxx[i] = now/i; 70 if(maxx[i] != maxx[i-1]) cas[++num] = maxx[i]; 71 } 72 getbase(now); 73 f[0] = 1; 74 for(int i=1;i<=1000;i++){ 75 for(int j=0;j<=now;j++){ 76 if(j > i) break; 77 if(j == i){f[i] += g[num][i]; f[i] %= mod;} 78 else{f[i]+=1ll*f[i-j-1]*pos[0]%mod*g[num][j]%mod;f[i]%=mod;} 79 } 80 } 81 if(n <= 1000) return f[n]; 82 else{ 83 memset(FF,0,sizeof(FF)); 84 FF[now+1] = 1; 85 for(int i=1;i<=now+1;i++) 86 FF[now+1-i]=(mod-1ll*g[num][i-1]*pos[0]%mod)%mod; 87 if(!(n&1)){for(int i=0;i<=now+1;i++) FF[i] = (mod-FF[i])%mod;} 88 fastpow(now+1,n); 89 int res = 0; 90 for(int i=0;i<=now;i++){res += 1ll*ans[i]*f[i]%mod; res %= mod;} 91 return res; 92 } 93 } 94 95 int main(){ 96 scanf("%d%d%d%d",&n,&k,&x,&y); 97 x = 1ll*x*fast_pow(y,mod-2)%mod; 98 for(int i=0;i<=k;i++){pos[i] = 1ll*fast_pow(x,i)*(mod+1-x)%mod;} 99 if(n == 1){printf("%d ",pos[k]);return 0;} 100 int ans = work(k)-work(k-1); 101 if(ans < 0) ans += mod; 102 printf("%d ",ans); 103 return 0; 104 }