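没有限制时,由隔板法,方程 $x_1+\dots+x_n=m$ 的正整数解个数就是组合数 $\binom{m-1}{n-1}$。对于 $x_i \ge A_i$(不小于某个数)的限制,令 $x_i' = x_i-(A_i-1)$,相当于直接从 $m$ 中减去 $A_i-1$;对于 $x_i \le A_i$(不大于某个数)的限制,很容易想到容斥:枚举哪些变量超过上限(即 $x_i \ge A_i+1$,相当于从 $m$ 中减去 $A_i$),再按所选变量个数的奇偶加减即可。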
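一般情况下 n、m 和模数 p 都在 1e9 级别,这样的组合数没办法直接计算。不过可以发现题目已经给定了模数(10007、262203414 = 2·3·11·397·10007、437367875 = 5³·7³·101²),这些模数分解后的最大质因子幂都不大(最大为 101² = 10201),因此用扩展 Lucas 对每个质因子幂分别求组合数,再用中国剩余定理合并就可以了。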
// BZOJ 3129 (SDOI2013): count positive integer solutions of
//   x_1 + x_2 + ... + x_n = m
// where x_1..x_{n1} carry upper bounds x_i <= A_i and the next n2 variables
// carry lower bounds x_i >= A_i, modulo a composite P.
//  * Lower bounds are removed by shifting: m -= A_i - 1.
//  * Upper bounds are handled by inclusion-exclusion over the set of violated
//    bounds (a violated bound means x_i >= A_i + 1, i.e. m -= A_i).
//  * Each term is C(m' - 1, n - 1) mod P, computed per prime power of P with
//    extended Lucas and recombined with the Chinese remainder theorem.
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <vector>

// Largest prime power p^k dividing P for which a factorial table is built
// (at most 16 MB); the moduli used by this problem need at most 101^2 = 10201.
const int kMaxPrimePower = 1 << 22;
// A positive 32-bit int has at most 9 distinct prime factors; slot 0 is unused.
const int kMaxFactors = 10;

int T, P, n, n1, n2, ans;
// long long: subtracting up to n1 + n2 bounds of ~1e9 from m overflows int.
long long m;
// a[1..n1]: the upper bounds A_i.
std::vector<int> a;
// P = c[1] * ... * c[t] with c[i] = p[i]^b[i]; s[i] holds a residue mod c[i].
int p[kMaxFactors], b[kMaxFactors], c[kMaxFactors], s[kMaxFactors], t;
// f[i][j] = product of the integers in [1, j] not divisible by p[i], mod c[i].
std::vector<int> f[kMaxFactors];

// Reads one optionally-signed decimal integer from stdin. Returns 0 at EOF
// instead of spinning forever waiting for a digit.
int read() {
    int x = 0, sign = 1;
    int ch = std::getchar();  // int, not char, so EOF is distinguishable
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// x = (x + y) mod mod, assuming 0 <= x, y < mod. The sum is formed in
// long long so moduli above 2^30 cannot overflow.
void inc(int &x, int y, int mod) {
    long long sum = static_cast<long long>(x) + y;
    if (sum >= mod) sum -= mod;
    x = static_cast<int>(sum);
}

// Extended Euclid: sets x, y so that u*x + v*y == gcd(u, v).
void exgcd(int u, int v, int &x, int &y) {
    if (v == 0) {
        x = 1;
        y = 0;
        return;
    }
    exgcd(v, u % v, x, y);
    int tmp = x;
    x = y;
    y = tmp - u / v * x;
}

// Multiplicative inverse of val modulo mod; requires gcd(val, mod) == 1.
int inv(int val, int mod) {
    int x, y;
    exgcd(val, mod, x, y);
    return (x % mod + mod) % mod;
}

// base^k mod mod by binary exponentiation (returns 1 when k == 0).
int ksm(int base, int k, int mod) {
    long long result = 1, cur = base % mod;
    for (; k > 0; k >>= 1) {
        if (k & 1) result = result * cur % mod;
        cur = cur * cur % mod;
    }
    return static_cast<int>(result);
}

// Exponent of prime in x! (Legendre's formula).
int legendre(int x, int prime) {
    int e = 0;
    for (long long q = prime; q <= x; q *= prime) e += static_cast<int>(x / q);
    return e;
}

// x! with every factor p[i] removed, modulo c[i] (extended Lucas recursion).
int fac(int x, int i) {
    if (x == 0) return 1;
    long long r = fac(x / p[i], i);
    // Full blocks of length c[i] all contribute the same product f[i][c[i]].
    r = r * ksm(f[i][c[i]], x / c[i], c[i]) % c[i];
    r = r * f[i][x % c[i]] % c[i];
    return static_cast<int>(r);
}

// C(nn, kk) mod c[i] for 0 <= kk <= nn.
int C(int nn, int kk, int i) {
    int e = legendre(nn, p[i]) - legendre(kk, p[i]) - legendre(nn - kk, p[i]);
    if (e >= b[i]) return 0;  // divisible by p[i]^b[i] = c[i]
    long long r = fac(nn, i);
    r = r * inv(fac(kk, i), c[i]) % c[i];
    r = r * inv(fac(nn - kk, i), c[i]) % c[i];
    r = r * ksm(p[i], e, c[i]) % c[i];
    return static_cast<int>(r);
}

// Combines the residues s[1..t] (mod c[i]) into a residue mod P via CRT.
int crt() {
    int res = 0;
    for (int i = 1; i <= t; ++i) {
        int rest = P / c[i];
        long long term = 1LL * s[i] * rest % P * inv(rest, c[i]) % P;
        inc(res, static_cast<int>(term), P);
    }
    return res;
}

// C(nn, kk) mod P; 0 when kk < 0 or nn < kk (including negative nn).
int calc(long long nn, long long kk) {
    if (kk < 0 || nn < kk) return 0;
    for (int i = 1; i <= t; ++i)
        s[i] = C(static_cast<int>(nn), static_cast<int>(kk), i);
    return crt();
}

// Inclusion-exclusion over the upper bounds a[k..n1]: `chosen` counts the
// bounds already forced to be violated, `rest` is the shifted right-hand side.
void dfs(int k, int chosen, long long rest) {
    if (k > n1) {
        int term = calc(rest - 1, n - 1);
        if (chosen & 1)
            inc(ans, (P - term) % P, P);
        else
            inc(ans, term, P);
        return;
    }
    dfs(k + 1, chosen + 1, rest - a[k]);  // x_k forced to exceed a[k]
    dfs(k + 1, chosen, rest);
}

// Factors mod into prime powers (filling p, b, c, t) and builds the tables f.
// Returns false if mod < 1 or some prime power is too large to tabulate.
bool prepare_modulus(int mod) {
    if (mod < 1) return false;
    t = 0;
    int rest = mod;
    for (int d = 2; 1LL * d * d <= rest; ++d) {
        if (rest % d != 0) continue;
        ++t;
        p[t] = d;
        b[t] = 0;
        c[t] = 1;
        while (rest % d == 0) {
            rest /= d;
            ++b[t];
            c[t] *= d;
        }
    }
    if (rest > 1) {
        ++t;
        p[t] = rest;
        b[t] = 1;
        c[t] = rest;
    }
    for (int i = 1; i <= t; ++i) {
        if (c[i] > kMaxPrimePower) return false;
        f[i].assign(c[i] + 1, 1);
        for (int j = 1; j <= c[i]; ++j)
            f[i][j] = (j % p[i] == 0)
                          ? f[i][j - 1]
                          : static_cast<int>(1LL * f[i][j - 1] * j % c[i]);
    }
    return true;
}

int main() {
#ifndef ONLINE_JUDGE
    std::freopen("bzoj3129.in", "r", stdin);
    std::freopen("bzoj3129.out", "w", stdout);
#endif
    T = read();
    P = read();
    if (!prepare_modulus(P)) {
        std::fprintf(stderr, "unsupported modulus %d\n", P);
        return 1;
    }
    while (T--) {
        n = read();
        n1 = read();
        n2 = read();
        m = read();
        a.assign(n1 + 1, 0);
        for (int i = 1; i <= n1; ++i) a[i] = read();
        // Lower bound x_i >= A_i: substitute x_i' = x_i - (A_i - 1) >= 1.
        for (int i = 1; i <= n2; ++i) m -= read() - 1;
        ans = 0;
        if (m > 0) dfs(1, 0, m);
        std::cout << ans << '\n';
    }
    return 0;
}