P5293 [HNOI2019]白兔之舞(单位根反演)
题目大意
有一张顶点数为 ((L+1) imes n) 的有向图。这张图的每个顶点由一个二元组 ((u,v)) 表示 ((0le ule L,1le vle n))。这张图不是简单图,对于任意两个顶点 ((u_1,v_1),(u_2,v_2)),如果 (u_1<u_2),则从 ((u_1,v_1)) 到 ((u_2,v_2)) 一共有 (w(v_1,v_2)) 条不同的边,如果 (u_1ge u_2) 则没有边。
白兔将在这张图上上演一支舞曲。白兔初始时位于该有向图的顶点 ((0,x))。
白兔将会跳若干步。每一步,白兔会从当前顶点沿任意一条出边跳到下一个顶点。白兔可以在任意时候停止跳舞(也可以没有跳就直接结束)。当到达第一维为 (L) 的顶点就不得不停止,因为该顶点没有出边。
假设白兔停止时,跳了 (m) 步,白兔会把这只舞曲给记录下来成为一个序列。序列的第 (i) 个元素为它第 (i) 步经过的边。
问题来了:给定正整数 (k) 和 (y (1le yle n)),对于每个 (t (0le t<k)),求有多少种舞曲(假设其长度为 (m))满足 (m mod k=t),且白兔最后停在了坐标第二维为 (y) 的顶点?
两支舞曲不同定义为它们的长度((m))不同或者存在某一步它们所走的边不同。
输出的结果对 (p) 取模。
数据范围
对于全部数据,(p) 为一个质数,(10^8<p<2^{30}),(1le nle 3),(1le xle n),(1le yle n),(0le w(i,j)<p),(1le kle 65536),(k) 为 (p-1) 的约数,(1le Lle 10^8)。
对于每组测试点,特殊限制如下:
- 测试点 (1,2):(Lle 10^5);
- 测试点 (3):(n=1,w(1,1)=1),(k) 的最大质因子为 (2);
- 测试点 (4):(n=1),(k) 的最大质因子为 (2);
- 测试点 (5):(n=1,w(1,1)=1);
- 测试点 (6):(n=1);
- 测试点 (7,8):(k) 的最大质因子为 (2)。
解题思路
不妨先假设兔子是从第一个格子开始的,走到最后一列的第一个格子
那么如果给定了 m 该如何做呢
发现邻接矩阵都是一样的,那么我先用矩阵求出走 m 步的方案数,然后再用组合数求出 m 个落脚点即可
如果用矩阵表示是这样的
让我们暂时不看(mod k = t) 的条件,因为我们完全可以将多项式乘个 (x^{k-t}) 变成 (mod k = 0) 的情况
将式子写下
现在把 (mod k=t) 加上
考虑拆掉 (it) 项
因为 (frac {x^2}2) 和 (frac {P-1}{2n}) 不一定是整数,因此我们采用第二个式子
发现两个多项式的差相同,将第二个多项式反转,MTT 即可
#include <queue>
#include <vector>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#define MP make_pair
#define ll long long
#define fi first
#define se second
using namespace std;
template <typename T>
void read(T &x) {
x = 0; bool f = 0;
char c = getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=1;
for (;isdigit(c);c=getchar()) x=x*10+(c^48);
if (f) x=-x;
}
template<typename F>
inline void write(F x, char ed = '
')
{
static short st[30];short tp=0;
if(x<0) putchar('-'),x=-x;
do st[++tp]=x%10,x/=10; while(x);
while(tp) putchar('0'|st[tp--]);
putchar(ed);
}
template <typename T>
inline void Mx(T &x, T y) { x < y && (x = y); }
template <typename T>
inline void Mn(T &x, T y) { x > y && (x = y); }
namespace MTT {
#define db double
#define op com operator
#define con const
const int N = 270500;
const db Pi = acos(-1.0);
int lim = 1, r[N], L, P;
struct com {
db x, y;
com(db a = 0, db b = 0) : x(a) , y(b) {}
op + (con com &w) con { return com(x + w.x, y + w.y); }
op - (con com &w) con { return com(x - w.x, y - w.y); }
op * (con com &w) con { return com(x * w.x - y * w.y, x * w.y + y * w.x); }
op * (con double w) con { return com(x * w, y * w); }
op - (void) con { return com(-x, -y); }
com conj() con { return com(x, -y); }
com mi() con { return com(-y, x); }
}A[N], B[N], C[N], D[N], E[N];
void dft(com *A) {
for (int i = 1;i < lim; i++)
if (r[i] > i) swap(A[i], A[r[i]]);
for (int i = 1;i < lim; i <<= 1) {
for (int j = 0;j < lim; j += (i << 1)) {
com *f = A + j, *g = f + i, *e = E + i;
for (int k = 0;k < i; k++) {
com x = f[k], y = g[k] * e[k];
f[k] = x + y, g[k] = x - y;
}
}
}
}
void idft(com *A) { dft(A), reverse(A + 1, A + lim); }
void init(int n) {
while (lim <= n + n) lim <<= 1, L++;
for (int i = 1;i < lim; i++)
r[i] = (r[i>>1]>>1) | ((i & 1) << (L - 1));
E[1] = com(1, 0);
for (int i = 2;i < lim; i <<= 1) {
com *e0 = E + i / 2, *e1 = E + i;
com w(cos(Pi / i), sin(Pi / i));
for (int j = 0;j < i; j += 2)
e1[j] = e0[j>>1], e1[j+1] = e1[j] * w;
}
}
void MTT(ll *f, ll *g, ll *ans, int n) {
init(n);
for (int i = 0;i <= n; i++) A[i] = com(f[i] >> 15, f[i] & 0x7fff);
for (int j = 0;j <= n; j++) C[j] = com(g[j] >> 15, g[j] & 0x7fff);
dft(A), dft(C);
for (int i = 0;i < lim; i++) {
int k = (i == 0) ? 0 : lim - i;
com a = (A[i] + A[k].conj()) * 0.5;
com b = (A[i] - A[k].conj()) * com(0, -0.5);
com c = (C[i] + C[k].conj()) * 0.5;
com d = (C[i] - C[k].conj()) * com(0, -0.5);
B[i] = a * c + a * d * com(0, 1), D[i] = b * c + b * d * com(0, 1);
}
idft(B), idft(D);
for (int i = 0;i <= n + n; i++) {
ll bx = (ll)(B[i].x / lim + 0.5) % P;
ll by = (ll)(B[i].y / lim + 0.5);
ll dx = (ll)(D[i].x / lim + 0.5);
ll dy = (ll)(D[i].y / lim + 0.5);
ans[i] = bx << 30;
ans[i] = (ans[i] + ((dx + by) << 15));
ans[i] = (ans[i] + dy) % P;
}
}
}
ll n, k, L, x, y, P;
struct Mat {
ll f[3][3];
Mat(void) { memset(f, 0, sizeof(f)); }
Mat operator * (Mat t) {
Mat tmp;
for (int i = 0;i < 3; i++)
for (int j = 0;j < 3; j++)
for (int k = 0;k < 3; k++)
tmp.f[i][j] = (tmp.f[i][j] + f[i][k] * t.f[k][j]) % P;
return tmp;
}
Mat operator * (ll x) {
Mat tmp;
for (int i = 0;i < 3; i++)
for (int j = 0 ;j < 3; j++)
tmp.f[i][j] = (f[i][j] * x) % P;
for (int i = 0;i < n; i++) tmp.f[i][i]++;
return tmp;
}
}M;
Mat Mi(Mat x, ll mi) {
Mat res; for (int i = 0;i < n; i++) res.f[i][i] = 1;
for (; mi; mi >>= 1, x = x * x)
if (mi & 1) res = res * x;
return res;
}
ll fpw(ll x, ll mi) {
ll res = 1; mi = mi % (P - 1);
mi += P - 1, mi %= (P - 1);
for (; mi; mi >>= 1, x = x * x % P)
if (mi & 1) res = res * x % P;
return res;
}
int st[500], tp;
ll getG(void) {
int x = P - 1;
for (int i = 2;i * i <= x; i++) {
if (x % i) continue;
st[++tp] = i;
while (x % i == 0) x /= i;
}
if (x != 1) st[++tp] = x;
for (int i = 2;i <= P; i++) {
int fl = 0;
for (int j = 1;j <= tp; j++)
if (fpw(i, (P - 1) / st[j]) == 1) {fl = 1; break;}
if (!fl) return fpw(i, (P - 1) / k);
}
return 0;
}
const int N = 300050;
ll f[N], g[N], h[N];
int main() {
read(n), read(k), read(L), read(x), read(y), read(P); x--, y--;
for (int i = 0;i < n; i++)
for (int j = 0;j < n; j++)
read(M.f[i][j]);
ll G = getG(); MTT::P = P;
for (int i = 0;i < k; i++) {
f[k-1-i] = Mi(M * fpw(G, i), L).f[x][y];
f[k-1-i] = f[k-1-i] * fpw(G, (ll)i * (i - 1) / 2) % P;
}
for (int i = 0;i < 2 * k; i++) {
g[i] = fpw(G, -(ll)i * (i - 1) / 2);
}
ll inv = fpw(k, P - 2);
MTT::MTT(f, g, h, 2 * k - 1);
for (int i = 0;i < k; i++)
write(fpw(G, (ll)i * (i - 1) / 2) * inv % P * h[i+k-1] % P);
return 0;
}