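这题其实一眼就能看出要用容斥:先对行、列分别单独容斥,最后对行列交叉的部分再容斥一次,就能拿到 70 分(暴力容斥)。记 $A$ 为“至少有一行同色”的方案集合,$B$ 为“至少有一列同色”的方案集合,则答案为 $|A\cup B| = |A|+|B|-|A\cap B| = 2|A|-|A\cap B|$,其中 $|A|=\sum_{i=1}^{n}(-1)^{i+1}\binom{n}{i}k^{n(n-i)+i}$,$|A\cap B|=k\sum_{i=1}^{n}\sum_{j=1}^{n}(-1)^{i+j}\binom{n}{i}\binom{n}{j}k^{(n-i)(n-j)}$。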
70分题解:
code:
// 70-point solution: direct inclusion-exclusion, O(n^2) time.
// The formulas count n x n grids coloured with k colours that contain at
// least one monochrome row or column, modulo 998244353:
//   |A u B| = |A| + |B| - |A n B| = 2|A| - |A n B|   (rows/columns symmetric)
// Only meant for the small subtasks: the power table holds n*n+1 entries.
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;

const long long mod = 998244353;
long long n, k;
vector<long long> ci;     // ci[e] = k^e mod p, for 0 <= e <= n*n
vector<long long> binom;  // binom[j] = C(n, j) mod p, for 0 <= j <= n

// Fill the power table and row n of Pascal's triangle.
// Only row n is needed, so it is built in place in O(n) memory instead of
// keeping the whole (n+1) x (n+1) table.
void pre() {
    ci.assign(n * n + 1, 1);
    for (long long e = 1; e <= n * n; e++) ci[e] = ci[e - 1] * k % mod;

    binom.assign(n + 1, 0);
    binom[0] = 1;
    for (long long i = 1; i <= n; i++)
        // Walk j downwards so binom[j-1] still holds row i-1 when it is read.
        for (long long j = i; j >= 1; j--)
            binom[j] = (binom[j] + binom[j - 1]) % mod;
}

int main() {
    // freopen("magic.in","r",stdin);
    // freopen("magic.out","w",stdout);
    cin >> n >> k;
    k %= mod;  // everything below is mod p; keeps every product below 2^63
    pre();

    long long ans = 0;
    // |A|: fix i monochrome rows (k colours each); the other (n-i)*n cells are free.
    for (long long i = 1; i <= n; i++) {
        long long term = ci[(n - i) * n] * binom[i] % mod * ci[i] % mod;
        ans = (i & 1) ? (ans + term) % mod : (ans - term + mod) % mod;
    }
    ans = ans * 2 % mod;  // |A| + |B|

    // Subtract |A n B|: i rows and j columns that are all monochrome must share
    // one colour (factor k), leaving (n-i)(n-j) free cells. Sign is (-1)^(i+j+1).
    for (long long i = 1; i <= n; i++) {
        for (long long j = 1; j <= n; j++) {
            long long term = ci[(n - i) * (n - j)] * binom[i] % mod * binom[j] % mod * k % mod;
            ans = ((i + j) & 1) ? (ans + term) % mod : (ans - term + mod) % mod;
        }
    }
    cout << ans;
    return 0;
}
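100分:把交叉项中对 $j$ 的求和用二项式定理合并,$\sum_{j=1}^{n}(-1)^{j}\binom{n}{j}k^{(n-i)(n-j)}=(k^{n-i}-1)^{n}-k^{n(n-i)}$。再令 $i\to n-i$,交叉项就化为 $|A\cap B|=k\sum_{i=0}^{n-1}(-1)^{i}\binom{n}{i}\left[(1-k^{i})^{n}-(-k^{i})^{n}\right]$,只需 $O(n)$ 次快速幂即可。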
官方code:
// Official 100-point solution, O(n log(n^2)) time.
// answer = 2|A| - |A n B|, where the double sum in |A n B| has been collapsed
// with the binomial theorem into a single sum over i.
#include <cstdio>

typedef long long ll;
const int N = 1000005;
const int mod = 998244353;

ll res, ans, n, k;
ll fac[N], inv[N];  // factorials and inverse factorials mod p

// Read a decimal integer (optional leading '-') from stdin, skipping any
// other characters. Stops at EOF instead of looping forever.
inline ll readlong() {
    ll x = 0, f = 1;
    int ch = getchar();
    while ((ch < '0' || ch > '9') && ch != EOF) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// x += y, then fold back into [0, mod) assuming both were already in range.
void update(ll &x, ll y) {
    x += y;
    if (x < 0) x += mod;
    if (x >= mod) x -= mod;
}

// x^e mod p for any (possibly negative or oversized) base and e >= 0.
ll ksm(ll x, ll e) {
    x %= mod;
    if (x < 0) x += mod;
    ll ret = 1;
    for (; e; e >>= 1) {
        if (e & 1) ret = ret * x % mod;
        x = x * x % mod;
    }
    return ret;
}

// Binomial coefficient C(x, y) mod p; 0 when y > x.
ll calc(ll x, ll y) {
    if (x < y) return 0;
    if (x == y) return 1;
    return fac[x] * inv[y] % mod * inv[x - y] % mod;
}

int main() {
    freopen("magic.in", "r", stdin);
    freopen("magic.out", "w", stdout);
    n = readlong();
    k = readlong();
    k %= mod;  // only k mod p matters from here on
    if (k < 0) k += mod;

    fac[0] = 1;
    for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
    inv[n] = ksm(fac[n], mod - 2);
    for (int i = n - 1; i >= 0; i--) inv[i] = inv[i + 1] * (i + 1) % mod;

    // |A| = sum_{i=1..n} (-1)^(i+1) C(n,i) k^(n(n-i)+i)
    for (int i = 1; i <= n; i++) {
        ll sign = (i & 1) ? 1 : mod - 1;  // (-1)^(i+1)
        ll a = calc(n, i) * sign % mod;
        // Use the full exponent: reducing it mod (p-1) (Fermat) is wrong when
        // k == 0 (mod p), since it would turn 0^(p-1) into 0^0 = 1.
        ll x = ksm(k, 1ll * n * (n - i) + i);
        update(ans, a * x % mod);
    }
    ans = 2 * ans % mod;

    // -|A n B| = k * sum_{i=0..n-1} (-1)^(i+1) C(n,i) [(1-k^i)^n - (-k^i)^n]
    for (int i = 0; i < n; i++) {
        ll tmp = (mod - ksm(k, i)) % mod;  // -k^i mod p
        ll x = (ksm(tmp + 1, n) - ksm(tmp, n) + mod) % mod;
        ll sign = (i & 1) ? 1 : mod - 1;   // (-1)^(i+1)
        ll a = calc(n, i) * sign % mod;
        update(res, a * x % mod);
    }
    res = k * res % mod;

    printf("%lld ", (ans + res) % mod);
    return 0;
}
本人code:
// Author's 100-point solution, O(n log(n^2)) time.
// answer = 2|A| - |A n B|, with |A n B| collapsed by the binomial theorem:
//   |A|     = sum_{i=1..n}   (-1)^(i+1) C(n,i) k^((n-i)n) k^i
//   |A n B| = k * sum_{i=0..n-1} (-1)^i C(n,i) [(1-k^i)^n - (-k^i)^n]
#include <cctype>
#include <cstdio>
#include <iostream>
#define N 1000005
using namespace std;
const long long mod = 998244353;
long long n, k;
long long jie[N], inv[N];  // factorials and inverse factorials mod p

// a^b mod p for any (possibly negative or oversized) base a and b >= 0.
long long ksm(long long a, long long b) {
    a %= mod;
    if (a < 0) a += mod;
    long long ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
    }
    return ans;
}

// Read a decimal integer (optional leading '-') from stdin.
long long read() {
    long long x = 0, f = 1;
    int c = getchar();
    while (c != EOF && !isdigit(c)) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c != EOF && isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}

// Binomial coefficient C(a, b) mod p, for 0 <= b <= a <= n.
long long C(long long a, long long b) {
    return jie[a] * inv[b] % mod * inv[a - b] % mod;
}

// Precompute factorials up to n and their inverses (one Fermat inverse).
void pre() {
    jie[0] = 1;
    for (long long i = 1; i <= n; i++) jie[i] = jie[i - 1] * i % mod;
    inv[n] = ksm(jie[n], mod - 2);
    for (long long i = n - 1; i >= 0; i--) inv[i] = inv[i + 1] * (i + 1) % mod;
}

int main() {
    n = read(), k = read();
    k %= mod;  // only k mod p matters; keeps every product below 2^63
    if (k < 0) k += mod;
    pre();

    long long ans = 0;
    // 2|A|: every term is fully reduced before it is added or subtracted,
    // so ans stays in [0, mod) throughout.
    for (long long i = 1; i <= n; i++) {
        long long term = C(n, i) * ksm(k, (n - i) * n) % mod * ksm(k, i) % mod;
        ans = (i & 1) ? (ans + term) % mod : (ans - term + mod) % mod;
    }
    ans = ans * 2 % mod;

    // Subtract |A n B|; the sign of each term is (-1)^(i+1).
    for (long long i = 0; i < n; i++) {
        long long ki = ksm(k, i);
        long long left = ksm((1 - ki + mod) % mod, n);  // (1 - k^i)^n
        long long right = ksm(ki, n);                   // (k^i)^n
        if (n & 1) right = (mod - right) % mod;         // -> (-k^i)^n
        long long term = k * (C(n, i) * ((left - right + mod) % mod) % mod) % mod;
        ans = (i & 1) ? (ans + term) % mod : (ans - term + mod) % mod;
    }
    cout << ans;
    return 0;
}
over