[LOJ 6485] LJJ 学二项式定理
题意
给定 (n,s,a_0,a_1,a_2,a_3), 求:
[Large left[ sum_{i=0}^n left( {nchoose i} cdot s^{i} cdot a_{imod 4}
ight)
ight] mod 998244353
]
(Tle 10^5) 组测试数据, (nle 10^{18};s,a_ile 10^9).
题解
一看 (n) 巨大无比显然不太能直接搞.
但是这个 (mod 4) 十分的玄妙, 我们尝试从它入手, 分别计算每个 (a_i) 所产生的贡献.
又因为组合数越界会变 (0), 于是答案可以写成这样:
[sum_{k=0}^3sum_{imod 4=k} {nchoose i}a_ks^i
]
(imod 4=k) 等价于 (4mid i-k). 于是我们有:
[sum_{k=0}^3sum_{4mid i-k} {nchoose i}a_ks^i
]
把下标换漂亮点并且把求和条件变成布尔表达式丢到和式里:
[sum_{k=0}^3sum_{i}[4mid i] {nchoose i+k}a_ks^{i+k}
]
然后我们如果能找个东西把 ([4mid i]) 反演掉就好了.
幸运的是在傅立叶变换中有个东西叫求和引理, 即当 (k ot mid n) 的时候有:
[sum_{j=0}^{n-1}(omega_n^k)^j=0
]
不难算出当 (kmid n) 的时候上式的值为 (n). 也就是说:
[frac 1 nsum_{j=0}^{n-1}(omega_n^k)^j=frac 1 nsum_{j=0}^{n-1}omega_n^{kj}=[kmid n]
]
那么我们就可以把这个东西代进去按照和式的套路搞一搞求和顺序和指标:
[egin{aligned}
&sum_{k=0}^3sum_{i}left (frac 1 4sum_{r=0}^3omega_4^{ir}
ight) {nchoose i+k}a_ks^{i+k}\
=&frac 1 4sum_{k=0}^3sum_{r=0}^3a_ksum_{i}{nchoose i+k}s^{i+k}omega_4^{ir} \
end{aligned}
]
那么问题变成了最内层的东西. 我们发现它有点类似二项式定理的形式, 但是 (omega_4^r) 上的指数不太对. 我们强行让它和二项式系数的部分一样:
[egin{aligned}
&frac 1 4sum_{k=0}^3sum_{r=0}^3a_ksum_{i}{nchoose i+k}s^{i+k}omega_4^{(i+k)r}omega_4^{-kr} \
=&frac 1 4sum_{k=0}^3sum_{r=0}^3a_komega_4^{-kr} sum_{i}{nchoose i+k}s^{i+k}omega_4^{(i+k)r}\
=&frac 1 4sum_{k=0}^3sum_{r=0}^3a_komega_4^{-kr} sum_{i}{nchoose i+k}(somega_4^r)^{i+k} \
=&frac 1 4sum_{k=0}^3sum_{r=0}^3a_komega_4^{-kr} (somega_4^r+1)^n
end{aligned}
]
然后就可以算了.
参考代码
其实 (omega_4) 就是虚数单位 (i)...
#include <bits/stdc++.h>
const int I=911660635;
const int NI=86583718;
const int MOD=998244353;
const int PHI=998244352;
const int INV4=748683265;
typedef long long intEx;
int Pow(int,int,int);
int main(){
int T;
scanf("%d",&T);
while(T--){
intEx n;
int s;
scanf("%lld%d",&n,&s);
int ans=0;
for(int i=0;i<4;i++){
int a;
scanf("%d",&a);
for(int j=0;j<4;j++)
(ans+=1ll*a*Pow(NI,i*j,MOD)%MOD*Pow((1ll*Pow(I,j,MOD)*s+1)%MOD,n%PHI,MOD)%MOD)%=MOD;
}
ans=1ll*ans*INV4%MOD;
printf("%d
",ans);
}
return 0;
}
inline int Pow(int a,int n,int p){
int ans=1;
while(n>0){
if(n&1)
ans=1ll*a*ans%p;
a=1ll*a*a%p;
n>>=1;
}
return ans;
}