好久没写题解了,就过来水两篇。
对于每一个人,考虑一个序列$A$,$A_I$表示当k取值为 i 时的答案。
如果说有两个人,我们可以把$(A+B)^k$二项式展开,这样就发现把两个人合并起来的操作就是一次卷积,直接NTT就可以了。
同类人有多个,直接暴力肯定是不行的。快速幂的话不知道会不会T,我是用了多项式取ln和exp(拉板子)。
#include<cmath> #include<cstdio> #include<cstring> #include<algorithm> #define MN 400002 using namespace std; int read_p,read_ca; inline int read(){ read_p=0;read_ca=getchar(); while(read_ca<'0'||read_ca>'9') read_ca=getchar(); while(read_ca>='0'&&read_ca<='9') read_p=read_p*10+read_ca-48,read_ca=getchar(); return read_p; } const int MOD=998244353,g=3; inline int mi(int x,int y){ int mmh=1; while (y){ if (y&1) mmh=1LL*mmh*x%MOD; y>>=1;x=1LL*x*x%MOD; } return mmh; } int tot,k,n,m,f[MN],mmh,I[MN],_I[MN],L[MN],R[MN],MMH,N[MN],A[MN],B[MN],e[MN],_e[MN],W[MN],C_a[MN],C_b[MN],N_c[MN],C[MN],D[MN],Q[MN],_A[MN],_B[MN],ANS[MN]; inline void M(int &x){while(x>=MOD)x-=MOD;while(x<0)x+=MOD;} inline int ask(int n,int k){ for (int i=1;i<=k+1;i++) N[i]=mi(i,k); for (int i=1;i<=k+1;i++) M(N[i]+=N[i-1]); n%=MOD; for (int i=0;i<=k+1;i++) L[i]=R[i]=n-i; for (int i=1;i<=k+1;i++) L[i]=1ll*L[i-1]*L[i]%MOD; for (int i=k;i>=0;i--) R[i]=1ll*R[i+1]*R[i]%MOD; mmh=0; for (int i=0;i<=k+1;i++){ MMH=N[i]; if (i>0) MMH=1LL*MMH*_I[i]%MOD*L[i-1]%MOD; if (i<k+1) MMH=1LL*MMH*_I[k+1-i]%MOD*((k+1-i)%2?-1:1)*R[i+1]%MOD; M(mmh+=MMH); } return mmh; } inline void inv(){ int base=mi(g,(MOD-1)/tot),_base=mi(base,MOD-2); e[0]=_e[0]=1; for (int i=1;i<=tot;i++) e[i]=1LL*e[i-1]*base%MOD,_e[i]=1LL*_e[i-1]*_base%MOD; } inline void NTT(int N,int a[],int w[]){ for (int j,i=j=0;i<N;i++){ if (i>j) swap(a[i],a[j]); for (int k=N>>1;(j^=k)<k;k>>=1); } for (int i=2;i<=N;i<<=1){ for (int k,j=k=0,s=tot/i;k<(i>>1);j+=s,k++) W[k]=w[j]; for (int m=i>>1,j=0;j<N;j+=i) for (int k=0;k<m;k++){ int A=j+k,B=A+m,z=1LL*a[B]*W[k]%MOD; M(a[B]=a[A]-z); M(a[A]+=z); } } } inline void cc(int n,int m,int a[],int b[],int c[]){ int N=1,i;while (N<(n+m)) N<<=1; for (i=0;i<n;i++) C_a[i]=a[i];fill(C_a+n,C_a+N,0); for (i=0;i<m;i++) C_b[i]=b[i];fill(C_b+m,C_b+N,0); NTT(N,C_a,e);NTT(N,C_b,e); for (i=0;i<N;i++) c[i]=1LL*C_a[i]*C_b[i]%MOD; NTT(N,c,_e); int w=mi(N,MOD-2); for (i=0;i<N;i++) c[i]=1LL*c[i]*w%MOD; } inline void _D(int n,int a[],int b[]){for (int i=0;i<n;i++) b[i]=1LL*(i+1)*a[i+1]%MOD;b[n]=0;} inline void _S(int n,int a[],int b[]){for (int i=n;i;i--) b[i]=1LL*a[i-1]*I[i]%MOD;b[0]=0;} void ny(int n,int a[],int b[]){ if (n==1) memset(b,0,sizeof(int)*tot),b[0]=mi(a[0],MOD-2);else{ ny((n+1)>>1,a,b); register int i; int N=1;while (N<(n<<1)) N<<=1; copy(a,a+n,N_c);fill(N_c+n,N_c+N,0); NTT(N,N_c,e);NTT(N,b,e); for (i=0;i<N;i++) b[i]=(2LL-1LL*N_c[i]*b[i]%MOD+MOD)*b[i]%MOD; NTT(N,b,_e); int w=mi(N,MOD-2); for (i=0;i<n;i++) b[i]=1LL*b[i]*w%MOD;fill(b+n,b+N,0); } } void sqrt(int n,int a[],int b[]){ if (n==1) memset(b,0,sizeof(int)*tot),b[0]=int(sqrt(a[0])+0.5);else{ sqrt((n+1)>>1,a,b); register int i; int N=1,w=I[2];while (N<(n<<1)) N<<=1; copy(b,b+n,D);fill(D+n,D+N,0); for (i=0;i<n;i++) M(D[i]<<=1); ny(n,D,C); cc(n,n,a,C,C); for (i=0;i<n;i++) b[i]=(1LL*w*b[i]+C[i])%MOD; } } inline void Ln(int n,int a[],int b[]){ memset(C,0,sizeof(int)*tot);memset(D,0,sizeof(int)*tot); _D(n,a,D);ny(n,a,C); cc(n,n,D,C,b); _S(n,b,b); } void exp(int n,int a[],int b[]){ if (n==1) memset(b,0,sizeof(int)*tot),b[0]=1;else{ exp((n+1)>>1,a,b); Ln(n,b,Q); int N=1,w=(MOD+1)>>1;while (N<(n<<1)) N<<=1; for (int i=0;i<n;i++) M(Q[i]=a[i]-Q[i]);M(Q[0]+=1); cc(n,n,Q,b,b); fill(b+n,b+N,0); } } void work(int n,int C[]){ if (n==0){ C[0]=1; for (int i=1;i<k;i++) C[i]=0; } for (int i=0;i<k;i++) A[i]=_I[i+1]; ny(k,A,B); for (int i=0;i<k;i++) A[i]=1LL*mi(n+1,i+1)*_I[i+1]%MOD; cc(k,k,A,B,C); } int main(){ scanf("%d%d%d",&k,&m,&n);k++; for(tot=1;tot<(k<<1);tot<<=1);inv(); I[1]=1;for (int i=2;i<MN;i++) I[i]=1LL*(MOD-MOD/i)*I[MOD%i]%MOD; f[0]=_I[0]=1;for (int i=1;i<MN;i++) _I[i]=1LL*_I[i-1]*I[i]%MOD,f[i]=1LL*f[i-1]*i%MOD; //scanf("%d%d%d",&k,&m,&n); //n=3;k=10; /*for (int i=0;i<=k;i++) S[i]=1LL*ask(n,i)*I[i]%MOD; for (int i=0;i<=k;i++){ int o=0; for (int j=0;j<=i;j++) o=(1LL*S[i]*I[j+1]+o)%MOD; printf("%d ",o); } puts(""); for (int i=0;i<=k;i++) printf("%d ",1LL*(mi(n+1,i+1)-1)*I[i+1]%MOD); puts(""); */ ANS[0]=1; for (int i=1;i<=n;i++){ int a,b,c; scanf("%d%d%d",&a,&b,&c); work(a-1,_A);work(b,_B); //for (int i=0;i<k;i++) printf("%d ",1ll*_B[i]*f[i]%MOD);puts(""); int tmp=mi(b-a+1,MOD-2); for (int i=0;i<k;i++) M(_B[i]-=_A[i]),_B[i]=1LL*_B[i]*tmp%MOD; Ln(k,_B,_A); for (int i=0;i<k;i++) _A[i]=1LL*_A[i]*c%MOD; exp(k,_A,_B); cc(k,k,ANS,_B,ANS); //for (int i=0;i<k;i++) printf("%d ",1LL*ANS[i]*f[i]);puts(""); } printf("%d ",1LL*ANS[k-1]*f[k-1]%MOD); }