题目大意:给点一颗包含 (n)个节点的无根树,有 (m)次询问,每次询问给出两个点 (u)和 (v),要求计算
(d_{r}(u,v))是以 (r)为根的树上 (u)到 (v)的“美丽路径”,它的定义为:
其中 (lca_{r}(u,v))是以节点 (r)为根的树中,点 (u)和点 (v)的最近公共祖先。(dis(u,v))等于 (u),(v)之间最短路径的边数。
输入:第一行输入 (n,m),接下来 (n-1)行给出连边情况,接下来 (m)行代表 (m)组询问。
输出:对于每个询问输出答案对998244353取模
数据范围:(1 leq n,m leq 1e5)
分析:令节点 (1)为根,简化问题。考虑要算的东西,发现它只与 (u)到 (v)的路径上的节点以及这些节点的“分支节点”有关。不明白的话可以画图具体算一下。考虑点 (u)到 (lca)上的节点 (u_{1},u_{2}...u_{k}),假设 (u_{p})为 (u_{k})的“分支节点”,那么无论是以 (u_{k})为根还是以 (u_{p})为根, (lca(u,v))都等于 (u_{k}),也就是说可以把 (u_{k})的“分支节点”对答案的贡献累加到 (u_{k})上。假设原本 (u_{k})对答案的贡献为 (w),那么现在就等于 ((num+1) cdot w),(num)是“分支节点”的个数,设 (siz[x])是以 (1)为根的树中以 (x)为根的子树大小,那么 (num=siz[u_{k}]-siz[u_{k-1}]),设 (u,v)之间的距离为 (dis),(dis=dep[u]+dep[v]-2 imes dep[lca])。那么 (u)到 (lca)上的节点 (u_{1},u_{2}...u_{k})对答案的贡献就等于
把它拆成 (8)项,分别计算就好,求下前缀和就可以 (O(1))计算。对于 (v)到 (lca)的那部分贡献同理计算。另外 (lca)对答案的贡献需要另算。
#include<cstdio>
typedef long long ll;
const int N = 1e5 + 5;
const int mod = 998244353;
int n, m, cnt, son_u, son_v;
int head[N], dep[N], son[N], fa[N], top[N];
ll d_siz[N], d2_siz[N], fa_d_siz[N], fa_d2_siz[N], siz[N];
// son_u表示u到lca路径上离lca最近的点,son_v同理
// d_siz[x] = dep[x] * siz[x]
// d2_siz[x] = dep[x] * dep[x] * siz[x]
// fa_d_siz[x] = dep[fa[x]] * siz[x]
// da_d2_siz[x] = dep[fa[x]] * dep[fa[x]] * siz[x]
struct Edge{
int nex, to;
}e[N << 1];
inline ll max(ll a, ll b) { return a > b ? a : b; }
inline void add(int a, int b) { e[++cnt] = {head[a], b}; head[a] = cnt; }
void dfs1(int u, int f){
dep[u] = dep[f] + 1, fa[u] = f, siz[u] = 1;
for(int i = head[u]; i; i = e[i].nex){
int to = e[i].to;
if(to == f) continue;
dfs1(to, u);
if(siz[to] > siz[son[u]]) son[u] = to;
siz[u] += siz[to];
}
}
void dfs2(int u, int ttop){
top[u] = ttop;
if(son[u]) dfs2(son[u], ttop);
for(int i = head[u]; i; i = e[i].nex){
int to = e[i].to;
if(to == fa[u] || to == son[u]) continue;
dfs2(to, to);
}
}
void dfs3(int u, int f){
d_siz[u] = (1LL * dep[u] * siz[u] + d_siz[f]) % mod;
d2_siz[u] = (1LL * dep[u] * dep[u] % mod * siz[u] + d2_siz[f]) % mod;
fa_d_siz[u] = (1LL * dep[fa[u]] * siz[u] + fa_d_siz[f]) % mod;
fa_d2_siz[u] = (1LL * dep[fa[u]] * dep[fa[u]] % mod * siz[u] + fa_d2_siz[f]) % mod;
for(int i = head[u]; i; i = e[i].nex){
int to = e[i].to;
if(to == f) continue;
dfs3(to, u);
}
}
// 找lca和son_u,son_v
int get_lca(int u, int v){
while(top[u] != top[v]){
if(dep[top[u]] > dep[top[v]]) son_u = top[u], u = fa[top[u]];
else son_v = top[v], v = fa[top[v]];
}
if(dep[u] > dep[v]) son_u = son[v];
else son_v = son[u];
return dep[u] > dep[v] ? v : u;
}
ll cal(int u, int v, ll *p){
// u或v等于0说明son_u不存在,返回0
return (dep[u] < dep[v] || v == 0 || u == 0) ? 0 : p[u] - p[v];
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 1, u, v; i < n; ++i){
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
dfs1(1, 0);
dfs2(1, 1);
dfs3(1, 0);
for(int i = 1, u, v; i <= m; ++i){
scanf("%d%d", &u, &v);
ll lca = get_lca(u, v), dis = dep[u] + dep[v] - (dep[lca] << 1);
if(u == lca) son_u = 0;
if(v == lca) son_v = 0;
ll ans = 1LL * (n - siz[son_u] - siz[son_v]) * (dep[u] - dep[lca]) % mod * (dep[v] - dep[lca]) % mod; // lca的贡献
ans -= ((dis - dep[u]) * cal(fa[u], lca, d_siz) + (dis - dep[v]) * cal(fa[v], lca, d_siz)) % mod;
ans += ((dis - dep[u]) * dep[u] % mod * max(0, siz[son_u] - siz[u]) + (dis - dep[v]) * dep[v] % mod * max(0, siz[son_v] - siz[v])) % mod;
ans += ((dis - dep[u]) * cal(u, son_u, fa_d_siz) + (dis - dep[v]) * cal(v, son_v, fa_d_siz)) % mod;
ans -= cal(fa[u], lca, d2_siz) + cal(fa[v], lca, d2_siz);
ans += cal(u, son_u, fa_d2_siz) + cal(v, son_v, fa_d2_siz);
ans += (dep[u] * cal(fa[u], lca, d_siz) + dep[v] * cal(fa[v], lca, d_siz)) % mod;
ans -= (dep[u] * cal(u, son_u, fa_d_siz) + dep[v] * cal(v, son_v, fa_d_siz)) % mod;
printf("%lld
", (ans % mod + mod) % mod);
}
return 0;
}