细节比较多，很多地方容易写挂（写错）。
code:
/*
 * For each test case: read n, m, l, r and n values a[1..n], then print
 *     (n + m)! / prod_v cnt_v!   (mod 998244353)
 * where cnt_v is the multiplicity of value v after m extra copies of values
 * from [l, r] have been added, each extra copy going to a value whose current
 * multiplicity is smallest. Presumably this is the maximum number of distinct
 * arrangements of the n + m values -- TODO confirm against the problem statement.
 *
 * Assumes (not checked here): l <= r, n < N, and n + m + 1 < MAXN.
 * Answers are printed space-separated on one line.
 */
#include <cstdio>
#include <map>
#include <string>
#include <algorithm>

#define N 200005
#define ll long long
#define MAXN 11000000
#define mod 998244353

using namespace std;

namespace IO {
char buf[100000], *p1, *p2;
// Buffered getchar: refills buf from stdin when it is exhausted, yields EOF at end of input.
#define nc() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1++)

// Reads a non-negative decimal integer, skipping leading characters below '0'.
int rd() {
    int x = 0;
    char s = nc();
    while (s < '0') s = nc();
    while (s >= '0') x = (((x << 2) + x) << 1) + s - '0', s = nc();  // x = 10 * x + digit
    return x;
}

// Writes a non-negative integer in decimal, with no separator.
void print(int x) {
    if (x >= 10) print(x / 10);
    putchar(x % 10 + '0');
}

// Redirects stdin/stdout to <s>.in / <s>.out (local testing helper).
void setIO(string s) {
    string in = s + ".in";
    string out = s + ".out";
    freopen(in.c_str(), "r", stdin);
    freopen(out.c_str(), "w", stdout);
}
}  // namespace IO

// fac[i] = i! mod p and inv[i] = (i!)^{-1} mod p, for 0 <= i < MAXN.
int fac[MAXN], inv[MAXN];
// a[1..n]: the input values of the current test case.
int a[N];
// bu[c] = how many values in [l, r] currently have multiplicity c.
// Only indices 0..n+1 are ever touched, and n < N.
int bu[N + 2];
// in[i] = i^{-1} mod p for 1 <= i <= N + 1 (solve reads in[1..n+1]).
int in[N + 2];

// Returns x^y mod p by binary exponentiation; x is expected to be in [0, p).
int qpow(int x, int y) {
    int tmp = 1;
    for (; y; y >>= 1, x = (ll)x * x % mod)
        if (y & 1) tmp = (ll)tmp * x % mod;
    return tmp;
}

// Reads one test case and prints its answer followed by a space.
void solve() {
    using namespace IO;
    int n, m, l, r, ty = 0, len, tot, i, j = 0, k;
    n = rd(), m = rd(), l = rd(), r = rd();
    len = r - l + 1;
    tot = fac[n + m];
    for (i = 1; i <= n; ++i) a[i] = rd();
    sort(a + 1, a + 1 + n);

    // Walk runs of equal values. Divide by cur! for every run. For values inside
    // [l, r], also record the run length in bu[] and track the largest one in j.
    // ty counts the distinct in-range values.
    for (i = 1; i <= n; i = k) {
        // Test k <= n before touching a[k] so the scan never reads past the input.
        for (k = i; k <= n && a[k] == a[i]; ++k) {
        }
        int cur = k - i;
        tot = (ll)tot * inv[cur] % mod;
        if (a[i] >= l && a[i] <= r) {
            ++bu[cur];
            ++ty;
            j = max(j, cur);
        }
    }
    bu[0] = len - ty;  // values of [l, r] absent from the input

    // Greedy by level: bump up to m values from multiplicity i to i + 1.
    // Each bump divides the answer by (i + 1). Whatever is not bumped at level i
    // has been fully raised and joins level i + 1.
    for (i = 0; i <= j; ++i) {
        int take = min(bu[i], m);
        tot = (ll)tot * qpow(in[i + 1], take) % mod;
        m -= take;
        if (!m) break;
        bu[i + 1] += bu[i];
    }

    if (i == j + 1 && m) {
        // Every one of the t values in [l, r] now has multiplicity j + 1.
        // Spread the remaining m copies as evenly as possible:
        //  - each value gains m / t copies, dividing by ((j+2)*...*ed)^t;
        //  - `remain` values gain one more copy, dividing by (ed + 1) each.
        int t = bu[j + 1];
        int st = j + 2;
        int ed = j + 1 + m / t;
        int remain = m - (m / t) * t;
        int de = qpow((ll)fac[ed] * inv[st - 1] % mod, t);
        tot = (ll)tot * qpow(de, mod - 2) % mod;
        if (remain) tot = (ll)tot * qpow(qpow(ed + 1, remain), mod - 2) % mod;
    }
    printf("%d ", tot);

    // Reset only the bucket entries this test case could have touched.
    for (i = 0; i <= j + 1; ++i) bu[i] = 0;
}

int main() {
    using namespace IO;
    // setIO("input");
    int T = rd();

    fac[0] = 1;
    for (int i = 1; i < MAXN; ++i) fac[i] = (ll)fac[i - 1] * i % mod;
    inv[MAXN - 1] = qpow(fac[MAXN - 1], mod - 2);
    for (int i = MAXN - 1; i; --i) inv[i - 1] = (ll)inv[i] * i % mod;

    // Linear-time modular inverses: i^{-1} = -(p / i) * (p mod i)^{-1} (mod p).
    in[0] = in[1] = 1;
    for (int i = 2; i < N + 2; ++i) in[i] = (ll)(mod - mod / i) * in[mod % i] % mod;

    while (T--) solve();
    return 0;
}