祭奠我逝去的一下午加一晚上=-=
Description
顺便 sto \(Zhang\_RQ\)学长 orz
Solution
这名字听起来就很高大上的样子(事实上确实如此
好吧其实我并没有打算推式子,因为 \(BJpers2\) 巨佬在他的题解中已经把式子推的很明白了,只是他的代码着实有些毒瘤,因此我这里只是想放下我自己的代码罢了。
简单说两句,推出特征多项式 \(p(x)\) 之后,就是要求 \(x^n \ \ mod \ \ p(x)\),然而我们这个 \(n\) 是 \(10^9\),没办法直接放到多项式里算,所以采用快速幂的思想。
快速幂过程
设最开始的多项式 \(t(x) = 1\),倍增往上跳,最大情况下 \(t\) 会是一个 \(k - 1\) 次的多项式乘一下就会变成 \(2 \times k - 2\) 次,然后去对 \(k\) 次的多项式 \(p(x)\) 取模。
次数体现在代码里的话,就是 \(Mod\) 函数中传的实参是 \(n << 1\)。
坑点
- 数组最好都开成局部变量,不然就各种错乱(我一开始用的全局变量数组就一直都是 0).
- 边界!边界!边界!好吧,说实话这玩意就算知道有坑点也没啥用,就算让我再写一遍可能也得调半天。
废话不多说,上代码吧,希望对您有帮助。
Code(大常数警告)
#include <bits/stdc++.h>
#define ll long long
using namespace std;
namespace IO{
inline ll read(){
ll x = 0, f = 1;
char ch = getchar();
while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x * f;
}
template <typename T> inline void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
inline void print(ll a[], ll n){
for(int i = 0; i <= n; ++i) printf("%lld ", a[i]);
puts("");
}
}
using namespace IO;
const ll N = 5e5 + 10;
const ll mod = 998244353;
const ll G = 3, Gi = 332748118;
ll n, m, k;
ll a[N], b[N], c[N], d[N], e[N], p[N], res[N], t[N];
ll f[N], g[N], ig[N], q[N], r[N];
namespace NTT{
ll lim, len;
inline ll qpow(ll a, ll b){
ll res = 1;
while(b){
if(b & 1) res = res * a % mod;
a = a * a % mod, b >>= 1;
}
return res;
}
inline void get_rev(ll n){
lim = 1, len = 0;
while(lim <= n) lim <<= 1, ++len;
for(int i = 0; i <= lim; ++i) p[i] = (p[i >> 1] >> 1) | ((i & 1) << (len - 1));
}
inline void ntt(ll A[], ll lim, ll type){
for(int i = 0; i <= lim; ++i)
if(i < p[i]) swap(A[i], A[p[i]]);
for(int mid = 1; mid < lim; mid <<= 1){
ll Wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
for(int i = 0; i < lim; i += (mid << 1)){
ll w = 1;
for(int j = 0; j < mid; ++j, w = w * Wn % mod){
ll x = A[i + j], y = w * A[i + j + mid] % mod;
A[i + j] = (x + y) % mod;
A[i + j + mid] = (x - y + mod) % mod;
}
}
}
if(type == 1) return;
ll inv = qpow(lim, mod - 2);
for(int i = 0; i <= lim; ++i) A[i] = A[i] * inv % mod;
}
inline void Mul(ll n, ll m, ll a[], ll b[], bool flag = 1){
static ll d[N], e[N];
for(int i = 0; i < (n << 2); ++i) d[i] = e[i] = 0;
for(int i = 0; i < n; ++i) d[i] = a[i], e[i] = b[i];
get_rev(n + m);
ntt(d, lim, 1), ntt(e, lim, 1);
for(int i = 0; i < lim; ++i) d[i] = d[i] * e[i] % mod;
ntt(d, lim, -1);
for(int i = 0; i < (n << 1); ++i) a[i] = d[i];
for(int i = (n << 1); i <= lim; ++i) a[i] = 0;
if(flag) for(int i = n; i < (n << 1); ++i) a[i] = 0;
}
inline void Inv(ll n, ll a[], ll b[]){
if(!n) return b[0] = qpow(a[0], mod - 2), void();
Inv(n >> 1, a, b);
get_rev(n << 1);
for(int i = 0; i <= n; ++i) c[i] = a[i];
for(int i = n + 1; i <= lim; ++i) c[i] = 0;
ntt(c, lim, 1), ntt(b, lim, 1);
for(int i = 0; i < lim; ++i) b[i] = (2ll - c[i] * b[i] % mod + mod) * b[i] % mod;
ntt(b, lim, -1);
for(int i = n + 1; i <= lim; ++i) b[i] = 0;
}
}
using namespace NTT;
inline void Mod(ll n, ll m, ll f[], ll g[], ll r[]){
static ll a[N], b[N];
for(int i = 0; i < (n << 2); ++i) a[i] = b[i] = d[i] = 0;
for(int i = 0; i < n - m + 1; ++i) a[i] = f[n - i - 1];
for(int i = 0; i < n - m + 1; ++i) b[i] = g[m - i - 1];
Inv(n - m + 1, b, d);
Mul(n - m + 1, n - m + 1, a, d);
for(int i = 0; i <= n - m; ++i) q[i] = a[n - m - i];
for(int i = 0; i < (n << 2); ++i) a[i] = b[i] = 0;
for(int i = 0; i < n; ++i) a[i] = f[i];
for(int i = 0; i < m; ++i) b[i] = g[i];
Mul(n, n, b, q);
for(int i = 0; i < m - 1; ++i) r[i] = (a[i] - b[i] + mod) % mod;
for(int i = m - 1; i < lim; ++i) r[i] = 0;
}
inline void solve(ll p, ll n){
t[1] = res[0] = 1;
while(p){
if(p & 1) Mul(n, n, res, t, 0), Mod(n << 1, n, res, g, res);// b % g --> b
Mul(n, n, t, t, 0), Mod(n << 1, n, t, g, t);
p >>= 1;
}
}
signed main(){
// freopen("P4723.in", "r", stdin);
// freopen("P4723.out", "w", stdout);
n = read(), m = read();
g[0] = 1;
for(int i = 1; i <= m; ++i) g[i] = (mod - (read() % mod + mod) % mod);
reverse(g, g + 1 + m);
for(int i = 0; i < m; ++i) f[i] = read();
solve(n, m + 1);
ll ans = 0;
for(int i = 0; i < m; ++i) ans = (ans + res[i] * f[i] % mod + mod) % mod;
write(ans), puts("");
return 0;
}
\[\_EOF\_
\]