裸做的话设一个 $p[i][j]$ 表示两个堆分别抽走 $i,j$ 个的概率.
转移的话就枚举当前是第几个,然后再枚举左/右面由下向上第几个贡献.
不在模意义下做,开 double 打表发现无论怎样洗牌,一次函数还是一次函数,二次函数还是二次函数.
那么我们只需暴力维护出牌的前 3 项,然后后面的项用拉格朗日插值求出即可.
code:
#include <cstdio> #include <cstring> #include <algorithm> #define N 500009 #define ll long long #define mod 998244353 #define setIO(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout) using namespace std; int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) { if(y&1) tmp=(ll)tmp*x%mod; } return tmp; } int get_inv(int x) { return qpow(x,mod-2); } namespace Lagrange { int x[5],y[5],dn[5]; void init() { for(int i=1;i<=3;++i) { dn[i]=1; for(int j=1;j<=3;++j) { if(i==j) continue; dn[i]=(ll)(x[i]-x[j]+mod)%mod*dn[i]%mod; } dn[i]=get_inv(dn[i]); } } int solve(int v) { int an=0; for(int i=1;i<=3;++i) { int up=1; for(int j=1;j<=3;++j) { if(i==j) continue; up=(ll)(v-x[j]+mod)%mod*up%mod; } (an+=(ll)y[i]*up%mod*dn[i]%mod)%=mod; } return an; } }; int n,m,ty; int a[N],tmp[10000009],A[N],p[4][4],inv[10000008]; void init() { inv[1]=1; for(int i=2;i<10000008;++i) { inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; } inv[0]=1; } void calc(int tmp) { memset(p,0,sizeof(p)); p[0][0]=1; int na=tmp,nb=n-tmp; for(int i=0;i<=min(3,na);++i) { for(int j=0;j<=min(3,nb);++j) { if(!i&&!j) continue; int tot=na-i+1+nb-j; if(i) { (p[i][j]+=(ll)p[i-1][j]*(na-i+1)%mod*inv[tot]%mod)%=mod; } if(j) { (p[i][j]+=(ll)p[i][j-1]*(nb-j+1)%mod*inv[tot]%mod)%=mod; } } } } int main() { // setIO("landlords"); scanf("%d%d%d",&n,&m,&ty); for(int i=n-2;i<=n;++i) { Lagrange::x[n-i+1]=i; Lagrange::y[n-i+1]=(ty==1?i:(ll)i*i%mod); } Lagrange::init(); init(); for(int i=1;i<=m;++i) { scanf("%d",&A[i]); } for(int i=1;i<=m;++i) { calc(A[i]); for(int j=1;j<=3;++j) { int cur=n-j+1,na=A[i],nb=n-A[i]; tmp[cur]=0; for(int k=1;k<=min(na,j);++k) { if(j-k<=nb) (tmp[cur]+=(ll)p[k-1][j-k]*(na-k+1)%mod*inv[n-j+1]%mod*Lagrange::solve(na-k+1)%mod)%=mod; } for(int k=1;k<=min(nb,j);++k) { if(j-k<=na) (tmp[cur]+=(ll)p[j-k][k-1]*(nb-k+1)%mod*inv[n-j+1]%mod*Lagrange::solve(n-k+1)%mod)%=mod; } } for(int j=1;j<=3;++j) { Lagrange::x[j]=n-j+1; Lagrange::y[j]=tmp[n-j+1]; } } int Q,x,y,z; scanf("%d",&Q); for(int i=1;i<=Q;++i) { scanf("%d",&x); printf("%d ",Lagrange::solve(x)); } return 0; }