题意
给定一棵有n个节点的无根树,树上的每个点有一个非负整数点权。定义一条路径的价值为路径上的点权和-路径上的点权最大值。 给定参数P,我!=们想知道,有多少不同的树上简单路径,满足它的价值恰好是P的倍数。 注意:单点算作一条路径;u!=v时,(u,v)和(v,u)只算一次。
题解
树上路径统计,解法是点分治。点分的时候求出根到每个点路径最大值和权值和。排一序,然后开个桶,就能计算了。去重就套路的减去没棵子树里面的答案。
CODE
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
typedef long long LL;
LL ans;
int n, mod, fir[MAXN], nxt[MAXN<<1], to[MAXN<<1], cnt, val[MAXN];
inline void link(int x, int y) {
to[++cnt] = y; nxt[cnt] = fir[x]; fir[x] = cnt;
to[++cnt] = x; nxt[cnt] = fir[y]; fir[y] = cnt;
}
bool ban[MAXN];
int getsz(int u, int ff) {
int re = 1;
for(int v, i = fir[u]; i; i = nxt[i])
if((v=to[i]) != ff && !ban[v])
re += getsz(v, u);
return re;
}
int getrt(int u, int ff, int &rt, int Size) {
int re = 1; bool can = 1;
for(int v, tmp, i = fir[u]; i; i = nxt[i])
if((v=to[i]) != ff && !ban[v]) {
re += (tmp = getrt(v, u, rt, Size));
if((tmp<<1) > Size) can = 0;
}
if(((Size-re)<<1) > Size) can = 0;
if(can) rt = u;
return re;
}
struct node {
int mx, v;
inline bool operator <(const node &o)const {
return mx < o.mx;
}
}seq[MAXN], vv[MAXN];
int tot;
void dfs(int u, int ff, int mx, int vs) {
vs = (vs + val[u]) % mod;
mx = max(mx, val[u]);
vv[u] = (node){ mx, vs };
for(int v, i = fir[u]; i; i = nxt[i])
if((v=to[i]) != ff && !ban[v])
dfs(v, u, mx, vs);
}
void push(int u, int ff) {
seq[++tot] = vv[u];
for(int v, i = fir[u]; i; i = nxt[i])
if((v=to[i]) != ff && !ban[v])
push(v, u);
}
int bin[10000005];
LL calc(int rt, int o) {
tot = 0; push(rt, 0);
sort(seq + 1, seq + tot + 1);
LL re = 0;
for(int i = 1; i <= tot; ++i) {
re += bin[((seq[i].mx+o-seq[i].v)%mod+mod)%mod];
++bin[seq[i].v%mod];
}
for(int i = 1; i <= tot; ++i) --bin[seq[i].v%mod];
return re;
}
void solve(int x) {
dfs(x, 0, 0, 0);
ans += calc(x, val[x]);
ban[x] = 1;
for(int v, i = fir[x]; i; i = nxt[i])
if(!ban[v=to[i]]) ans -= calc(v, val[x]);
}
void TDC(int x) {
int Size = getsz(x, 0);
getrt(x, 0, x, Size);
solve(x);
for(int v, i = fir[x]; i; i = nxt[i])
if(!ban[v=to[i]]) TDC(v);
}
int main () {
scanf("%d%d", &n, &mod);
for(int i = 1, x, y; i < n; ++i)
scanf("%d%d", &x, &y), link(x, y);
for(int i = 1; i <= n; ++i) scanf("%d", &val[i]);
TDC(1);
printf("%lld
", ans+n);
}