「2019 集训队互测 Day 1」最短路径 题解
算法标签: 分治,ntt。
这题主要考察了对于分治的应用。
首先考虑最简单的“树”的情况。很容易想到,可以点分治+卷积实现。
然后只剩下环的情况了。
设环上第 $i$ 个点所挂的子树中（包含该点本身），与该点距离为 $j$ 的点的个数为 $[x^j]f_i$。
设环长为 $len$。
我们从任意一个位置破环成链，然后再复制一份接在后面，得到一个长度为 $2\cdot len$ 的序列。
则我们要算出：$\sum f_i \times f_j \times x^{j-i}\ \ (j>i,\ 0\leq i<len,\ j-i\leq \lfloor (len-1)/2\rfloor)$。特别地，当 $len$ 为偶数时，还需要单独处理 $j-i=len/2$（两个方向一样短）的情况。
然后把这个长度为 $2\cdot len$ 的序列按每段 $\lfloor (len-1)/2\rfloor$ 个位置切成若干段（$len$ 较大时约 4～5 段），则同一段内部的任意 $i<j$ 一定满足 $j-i\leq \lfloor (len-1)/2\rfloor$ 这个条件，所以段内直接分治计算就可以了。
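然后考虑相邻两段之间的贡献（不相邻的两段之间 $j-i$ 一定超过 $\lfloor (len-1)/2\rfloor$，没有贡献）：对右段中的位置 $j$，合法的 $i$ 恰好是左段中满足 $i\geq j-\lfloor (len-1)/2\rfloor$ 的那些，这个区间的左端点随 $j$ 单调右移，因此可以像 DP 决策单调性优化那样分治计算。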
code:
为了减小常数，代码在段内采用带权分治：按各 $f_i$ 的长度选择分割点，使两侧多项式的总长度尽量均衡。
/*
{
######################
# Author #
# Gary #
# 2021 #
######################
*/
#include <bits/stdc++.h>
// Loop shorthands: rb iterates a over [b, c] ascending, rl over [b, c] descending.
#define rb(a,b,c) for(int a=b;a<=c;++a)
#define rl(a,b,c) for(int a=b;a>=c;--a)
// Type / container shorthands (several of these macros are unused in this file).
#define LL long long
#define IT iterator
#define PB push_back
#define II(a,b) make_pair(a,b)
#define FIR first
#define SEC second
// Debug helper: redirects stdout to check.out.
#define FREO freopen("check.out","w",stdout)
// rep iterates a over [0, b).
#define rep(a,b) for(int a=0;a<b;++a)
// SRAND declares a time-seeded mt19937 named rng; random(a) draws from it.
// NOTE(review): the argument of random is not parenthesized, so random(x + 1)
// expands to rng()%x+1.
#define SRAND mt19937 rng(chrono::steady_clock::now().time_since_epoch().count())
#define random(a) rng()%a
#define ALL(a) a.begin(),a.end()
#define POB pop_back
#define ff fflush(stdout)
#define fastio ios::sync_with_stdio(false)
// In-place updates: a = min(a, b) / a = max(a, b).
#define check_min(a,b) a=min(a,b)
#define check_max(a,b) a=max(a,b)
using namespace std;
// "Infinity" sentinel (0x3f3f3f3f, about 1.06e9); INF + INF still fits in int.
const int INF = 0x3f3f3f3f;
typedef pair<int, int> mp;
/*
 * Reads the next non-negative decimal integer from stdin.
 *
 * Any non-digit characters before the number are skipped. A leading '-' is
 * skipped too and is not applied as a sign. Returns 0 if end of input is
 * reached before any digit.
 *
 * `ch` is an int so that EOF is distinguishable from every real character.
 * The skip loop stops at EOF; otherwise truncated input would make it spin forever.
 */
inline int read() {
    int x = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 1) + (x << 3) + (ch ^ 48);  // x * 10 + digit
        ch = getchar();
    }
    return x;
}
// NTT template begins
// Prime modulus 998244353 = 119 * 2^23 + 1 with primitive root 3, so NTTs of
// power-of-two length up to 2^23 have the required roots of unity.
int MOD = 998244353;
int g = 3;
// Current transform length (a power of two): set by getlen(), read by butterfly() and ntt().
int len;
// Bit-reversal permutation table, rebuilt by butterfly(); holds transforms of length up to 2^19.
int rev[1 << 19];
void butterfly(vector<int> &v) {
rep(i, len) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1)
rev[i] |= len >> 1;
}
rep(i, len) if (rev[i] > i)
swap(v[i], v[rev[i]]);
}
/*
 * Returns A^B modulo MOD using square-and-multiply.
 * B is expected to be non-negative; B == 0 yields 1.
 */
int quick(int A, int B) {
    long long result = 1;
    long long base = A;
    while (B) {
        if (B & 1)
            result = result * base % MOD;
        base = base * base % MOD;
        B >>= 1;
    }
    return (int)result;
}
/*
 * Modular inverse of x modulo the prime MOD, computed via Fermat's little
 * theorem as x^(MOD-2). x must not be divisible by MOD.
 */
int inv(int x) {
    const int fermat_exponent = MOD - 2;
    return quick(x, fermat_exponent);
}
/*
 * Iterative number-theoretic transform of v. v.size() is expected to equal
 * the global `len`, a power of two.
 *
 * ty == 1 computes the forward transform. ty == -1 uses inverse roots, i.e.
 * the inverse transform WITHOUT the final division by len (mul() does that).
 * Coefficients are reduced modulo MOD first.
 */
vector<int> ntt(vector<int> v, int ty) {
    for (int &coef : v)
        coef %= MOD;
    butterfly(v);
    for (int span = 2; span <= len; span <<= 1) {
        const int halfSpan = span / 2;
        int root = quick(g, (MOD - 1) / span);
        if (ty == -1)
            root = inv(root);
        for (int start = 0; start < len; start += span) {
            int w = 1;
            for (int k = 0; k < halfSpan; ++k) {
                // Each butterfly reads and writes only its own two slots, so it can run in place.
                const int lo = v[start + k];
                const int hi = 1ll * w * v[start + halfSpan + k] % MOD;
                v[start + k] = (lo + hi) % MOD;
                v[start + halfSpan + k] = (lo - hi + MOD) % MOD;
                w = 1ll * w * root % MOD;
            }
        }
    }
    // After at least one pass the result is exactly len entries long.
    if (len > 1)
        v.resize(len);
    return v;
}
/*
 * Sets the global transform length `len` to the smallest power of two
 * that is >= x (1 when x <= 1).
 */
void getlen(int x) {
    int power = 1;
    while (power < x)
        power *= 2;
    len = power;
}
vector<int> mul(vector<int> A, vector<int> B) {
getlen(A.size() + B.size());
A.resize(len);
B.resize(len);
A = ntt(A, 1);
B = ntt(B, 1);
rep(i, len) A[i] = 1ll * A[i] * B[i] % MOD;
A = ntt(A, -1);
int iv = inv(len);
rep(i, len) {
A[i] = 1ll * A[i] * iv % MOD;
}
while (!A.empty() && A.back() == 0)
A.pop_back();
return A;
}
void add(vector<int> &A, vector<int> B) {
if (A.size() < B.size())
A.resize(B.size());
rep(i, B.size()) {
(A[i] += B[i]) %= MOD;
}
}
/*
 * Returns A multiplied by x^x, i.e. A with x zero coefficients prepended.
 * A non-positive x returns A unchanged.
 */
vector<int> right_shift(vector<int> A, int x) {
    if (x > 0)
        A.insert(A.begin(), static_cast<size_t>(x), 0);
    return A;
}
//NTT template ends
const int MAXN = 1e5 + 233;
// on_cycle[v] != 0 marks v as blocked for tree traversals. It is set for
// cycle vertices (cycle case) and temporarily for the current centroid in get().
int on_cycle[MAXN];
int n, k;
// Adjacency lists over vertices 1..n; self-loops are not stored.
vector<int> gra[MAXN];
// Subtree sizes, filled by get_centroid().
int siz[MAXN];
// Size of the component currently being decomposed, accumulated by getsize().
int sz = 0;
// Number of vertices on the cycle (cycle case only).
int anslen;
/*
 * Searches the component containing `now` for a centroid-like vertex.
 * The search uses only vertices with on_cycle == 0 and enters `now` from
 * parent `fa`.
 *
 * A vertex qualifies when the largest piece left after removing it has at
 * most sz / 2 + 3 vertices (sz is the global component size). The +3 slightly
 * loosens the usual centroid bound.
 *
 * Returns the first qualifying vertex found in post-order, or 0 when the
 * subtree of `now` contains none. siz[] is filled for the visited vertices.
 */
int get_centroid(int now, int fa = -1) {
    siz[now] = 1;
    int heaviest = -INF;
    for (int child : gra[now]) {
        if (child == fa || on_cycle[child])
            continue;
        const int found = get_centroid(child, now);
        if (found != 0)
            return found;
        siz[now] += siz[child];
        heaviest = max(heaviest, siz[child]);
    }
    heaviest = max(heaviest, sz - siz[now]);  // the piece containing the parent side
    return heaviest <= sz / 2 + 3 ? now : 0;
}
// Vertices of the unique cycle, filled by findcycle() in order along the cycle.
vector<int> cycle;
// Set when the input contains a self-loop; the remaining edges then form a tree.
bool ok = 0;
bool vis[MAXN];
// DFS stack holding the current root-to-vertex path; the cycle is read off it.
stack<int> sta;
/*
 * Depth-first search from `now` (entered from vertex `pre`). It finds the
 * graph's unique cycle and stores its vertices in `cycle`, then stops as soon
 * as the cycle is found.
 *
 * Only ONE adjacency entry pointing back to `pre` is treated as the tree edge
 * we arrived by. A second entry for `pre` is a parallel edge, i.e. a cycle of
 * length 2. Skipping every entry equal to `pre` would miss that cycle and
 * later pop an empty stack.
 */
void findcycle(int now, int pre = -1) {
    if (cycle.size())
        return;
    vis[now] = true;
    sta.push(now);
    bool skipped_parent_edge = false;
    for (auto it : gra[now]) {
        if (it == pre && !skipped_parent_edge) {
            skipped_parent_edge = true;
            continue;
        }
        if (vis[it]) {
            // Back edge to an ancestor: the cycle is the stack segment from now down to it.
            int Now;
            do {
                Now = sta.top(), sta.pop();
                cycle.PB(Now);
            } while (Now != it);
            return;
        }
        findcycle(it, now);
        if (cycle.size())
            return;
    }
    sta.pop();
}
// f[i]: depth polynomial of the tree hanging off the i-th cycle vertex, i.e.
// [x^j]f[i] = number of its vertices at distance j. Entries at index
// >= cycle length hold copies, so the cycle can be unrolled twice.
vector<int> f[MAXN * 2];
// ret[d]: accumulated (mod MOD) number of vertex pairs at distance d.
vector<int> ret;
/*
 * For every vertex u reachable from `now` without passing through `pre` or a
 * vertex with on_cycle set, increments v[depth + dist(now, u)].
 * v is grown as needed.
 */
void calc(vector<int> &v, int now, int depth = 0, int pre = -1) {
    if ((int)v.size() < depth + 1)
        v.resize(depth + 1);
    ++v[depth];
    for (const int next : gra[now]) {
        if (next == pre || on_cycle[next])
            continue;
        calc(v, next, depth + 1, now);
    }
}
/*
 * Adds to the global `sz` the number of vertices reachable from `now` without
 * passing through `pre` or any vertex with on_cycle set. `now` itself is included.
 */
void getsize(int now, int pre = -1) {
    ++sz;
    for (const int next : gra[now])
        if (next != pre && !on_cycle[next])
            getsize(next, now);
}
/*
 * Centroid decomposition of the tree component containing `now` (vertices
 * with on_cycle == 0). For each pair of vertices in the component, adds one
 * count (mod MOD) to ret[dist] at the centroid that separates them or is one
 * of them.
 *
 * The children's depth polynomials are merged in increasing order of length.
 * The running prefix `presum` is then never longer than the polynomial it is
 * multiplied with, so each level costs O(sz log sz). Merging in adjacency
 * order is quadratic when a long child comes before many tiny ones (e.g. a
 * long path listed before many leaves).
 *
 * on_cycle[] is restored to its entry state before returning.
 */
void get(int now) {
    sz = 0;
    getsize(now);
    now = get_centroid(now);
    bool pre = on_cycle[now];
    on_cycle[now] = 1;  // block the centroid while handling its component
    vector<int> subtrees;
    {
        // One depth polynomial per child subtree, measured from the centroid.
        vector<vector<int>> layers;
        for (auto it : gra[now])
            if (!on_cycle[it]) {
                subtrees.push_back(it);
                vector<int> tmp;
                calc(tmp, it);
                layers.push_back(right_shift(tmp, 1));
            }
        sort(layers.begin(), layers.end(),
             [](const vector<int> &a, const vector<int> &b) { return a.size() < b.size(); });
        vector<int> presum(1, 1);  // the centroid itself, at distance 0
        for (auto &layer : layers) {
            add(ret, mul(presum, layer));
            add(presum, layer);
        }
    }  // release this level's polynomials before recursing
    for (auto it : subtrees)
        get(it);
    on_cycle[now] = pre;
}
int to;
void solve(int ansl, int ansr, int l, int r) {
check_min(ansr, anslen);
if (ansl >= ansr || l >= r)
return;
vector<int> lp, rp;
rb(i, ansl, ansr - 1) {
int st = ansr - 1 - i;
if (st + f[i].size() > lp.size())
lp.resize(st + f[i].size());
rep(j, f[i].size()) {
(lp[j + st] += f[i][j]) %= MOD;
}
}
rb(i, l, r - 1) {
int st = i - l;
if (st + f[i].size() > rp.size())
rp.resize(st + f[i].size());
rep(j, f[i].size()) {
(rp[j + st] += f[i][j]) %= MOD;
}
}
vector<int> tmp = mul(lp, rp);
int gap = l - ansr + 1;
rep(i, tmp.size()) {
(ret[i + gap] += tmp[i]) %= MOD;
}
}
/*
 * Counts pairs between two adjacent chunks of the unrolled cycle. Each right
 * position j in [l, r) is paired with left positions i from the previous
 * chunk (ending at ansr) such that i >= j - to, i.e. j - i <= to.
 *
 * The valid left range [j - to, ansr) moves right as j grows, so the right
 * range is halved recursively, as in monotone-decision divide and conquer:
 *   - the right half [mid, r) can only use i >= mid - to (first recursive call);
 *   - the left half [l, mid) pairs with all of [mid - to, ansr) in one solve()
 *     call, and with i < mid - to through the second recursive call.
 * ansl is only passed through; the base case uses l - to as its lower bound.
 */
void divc(int ansl, int ansr, int l, int r) {
// Single right position j = l: pair it with every i in [l - to, ansr).
if (l == r - 1) {
solve(l - to, ansr, l, r);
return ;
}
int mid = (l + r) >> 1;
int ansmid = mid - to;
divc(ansmid, ansr, mid, r);
divc(ansl, ansmid, l, mid);
// Every j < mid is within range of every i >= mid - to.
solve(ansmid, ansr, l, mid);
}
void div1(int l, int r) {
if (l == r - 1)
return ;
mp best = {INF, INF};
int tot = 0;
rb(i, l, r - 1) tot += f[i].size();
tot /= 2;
rb(i, l, r - 1) {
tot -= f[i].size();
check_min(best, II(abs(tot), i));
}
int mid = best.second + 1;
solve(l, mid, mid, r);
div1(l, mid);
div1(mid, r);
}
/*
 * Entry point.
 *
 * Input: n and k, then n edges (u, v) on vertices 1..n.
 * Output:
 *   (sum over unordered vertex pairs {a, b} of dist(a, b)^k) / (n(n-1)/2)   (mod MOD)
 *
 * n vertices with n edges form a tree plus one extra edge. The code assumes
 * the graph is connected, since only vertex 1's component is explored.
 *   - If the extra edge is a self-loop, the other edges form a tree and
 *     get(1) counts everything.
 *   - Otherwise the unique cycle is found and each hanging tree is handled by
 *     get(). Pairs whose path crosses the cycle are counted by unrolling the
 *     cycle twice and combining the hanging-tree depth polynomials f[].
 * ret[d] collects the number of pairs at distance d.
 */
int main() {
    n = read();
    k = read();
    rb(i, 1, n) {
        int u = read();
        int v = read();
        if (u == v) {
            // A self-loop is never on a shortest path; the remaining edges form a tree.
            ok = true;
            continue;
        }
        gra[u].PB(v), gra[v].PB(u);
    }
    int rest = 0;
    ret.resize(n + 1);
    if (ok) {
        get(1);
    } else {
        findcycle(1);
        for (auto it : cycle)
            on_cycle[it] = true;
        // For each cycle vertex: record its hanging tree's depth polynomial in
        // f[] and count pairs inside that tree. The vertex is unblocked only
        // while its own tree is processed.
        int now = 0;
        for (auto it : cycle) {
            on_cycle[it] = false;
            calc(f[now++], it), get(it);
            on_cycle[it] = true;
        }
        anslen = cycle.size();
        // Unroll the cycle twice so every forward walk i -> j (i < anslen) is
        // a contiguous index range.
        cycle.resize(anslen + anslen);
        rep(i, anslen) cycle[i + anslen] = cycle[i], f[i + anslen] = f[i];
        // to = floor((anslen - 1) / 2): the largest forward distance that is
        // strictly the shorter way round the cycle.
        to = anslen / 2;
        if (!(anslen & 1)) {
            to--;
            // Even cycle: vertices i and i + anslen/2 are anslen/2 apart both
            // ways. Count each such pair once (i < anslen/2).
            vector<int> tmp;
            rep(i, anslen) if (i < (i + to + 1) % anslen)
                add(tmp, mul(f[i], f[i + to + 1]));
            tmp = right_shift(tmp, to + 1);
            add(ret, tmp);
        }
        if (to) {
            // Cut the unrolled cycle into chunks of length `to`. Pairs inside
            // a chunk are all in range (div1). Pairs between adjacent chunks
            // are handled by divc. Chunks further apart are more than `to` apart.
            vector<mp> each;
            int pos = 0;
            while (pos < anslen + anslen) {
                int nex = min(anslen + anslen, pos + to);
                each.PB(II(pos, nex));
                div1(pos, nex);
                pos = nex;
            }
            for (size_t i = 1; i < each.size(); ++i) {
                divc(each[i - 1].FIR, each[i - 1].SEC, each[i].FIR, each[i].SEC);
            }
        }
    }
    rb(i, 0, n) {
        rest += 1ll * ret[i] * quick(i, k) % MOD;
        if (rest >= MOD)
            rest -= MOD;
    }
    // Divide by the number of unordered pairs C(n, 2).
    // NOTE(review): for n == 1 this is inv(0) == 0, so the program prints 0.
    rest = 1ll * rest * inv(1ll * n * (n - 1) / 2 % MOD) % MOD;
    cout << rest << '\n';
    return 0;
}