You are given a weighted tree consisting of $n$ vertices. Recall that a tree is a connected graph without cycles. Vertices $u_i$ and $v_i$ are connected by an edge with weight $w_i$.

You are given $m$ queries. The $i$-th query is given as an integer $q_i$. In this query you need to calculate the number of pairs of vertices $(u, v)$ ($u < v$) such that the maximum weight of an edge on the simple path between $u$ and $v$ doesn't exceed $q_i$.

Input

The first line of the input contains two integers $n$ and $m$ ($1 \le n, m \le 2 \cdot 10^5$): the number of vertices in the tree and the number of queries.

Each of the next $n - 1$ lines describes an edge of the tree. Edge $i$ is denoted by three integers $u_i$, $v_i$ and $w_i$: the labels of the vertices it connects ($1 \le u_i, v_i \le n$, $u_i \ne v_i$) and the weight of the edge ($1 \le w_i \le 2 \cdot 10^5$). It is guaranteed that the given edges form a tree.

The last line of the input contains $m$ integers $q_1, q_2, \dots, q_m$ ($1 \le q_i \le 2 \cdot 10^5$), where $q_i$ is the maximum weight of an edge in the $i$-th query.

Output

Print $m$ integers: the answers to the queries. The $i$-th value should be equal to the number of pairs of vertices $(u, v)$ ($u < v$) such that the maximum weight of an edge on the simple path between $u$ and $v$ doesn't exceed $q_i$. Queries are numbered from $1$ to $m$ in the order of the input.

Sample 1

Input:
7 5
1 2 1
3 2 3
2 4 1
4 5 2
5 7 4
3 6 2
5 2 3 4 1

Output:
21 7 15 21 3

Sample 2

Input:
1 2
1 2

Output:
0 0

Sample 3

Input:
3 3
1 2 1
2 3 2
1 3 2

Output:
1 3 3
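To pin down the definition above, and to cross-check a faster solution on small inputs, here is a brute-force sketch. The `Edge` struct and the `bruteForce` function are hypothetical names, not part of the original. It runs a DFS from every vertex to record the maximum edge weight on the path to every other vertex. That costs $O(n^2)$ preprocessing plus $O(n^2)$ per query, so it is only meant for tiny trees.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical edge type for this sketch: endpoints u, v (1-indexed) and weight w.
struct Edge { int u, v, w; };

// Brute-force reference: for every query q, count pairs (u, v), u < v,
// whose path maximum edge weight is <= q. Only for small n.
std::vector<long long> bruteForce(int n, const std::vector<Edge>& edges,
                                  const std::vector<int>& queries) {
    assert(static_cast<int>(edges.size()) == n - 1);  // input must be a tree
    std::vector<std::vector<std::pair<int, int>>> adj(n + 1);  // (neighbor, weight)
    for (const Edge& e : edges) {
        adj[e.u].push_back({e.v, e.w});
        adj[e.v].push_back({e.u, e.w});
    }
    // pathMax[s][t] = maximum edge weight on the tree path between s and t.
    std::vector<std::vector<int>> pathMax(n + 1, std::vector<int>(n + 1, 0));
    for (int s = 1; s <= n; ++s) {
        std::vector<int> stack = {s};
        std::vector<bool> seen(n + 1, false);
        seen[s] = true;
        while (!stack.empty()) {  // iterative DFS from s
            int x = stack.back();
            stack.pop_back();
            for (auto [y, w] : adj[x]) {
                if (seen[y]) continue;
                seen[y] = true;
                pathMax[s][y] = std::max(pathMax[s][x], w);
                stack.push_back(y);
            }
        }
    }
    std::vector<long long> answers;
    for (int q : queries) {
        long long cnt = 0;
        for (int u = 1; u <= n; ++u)
            for (int v = u + 1; v <= n; ++v)
                if (pathMax[u][v] <= q) ++cnt;
        answers.push_back(cnt);
    }
    return answers;
}
```

Feeding it Sample 1 (edges in input order, queries `5 2 3 4 1`) reproduces the expected output `21 7 15 21 3`.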
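```cpp
#include <bits/stdc++.h>
using namespace std;
#define M 200005

// Fast integer reader (also handles a leading '-').
template <class G> void read(G &x) {
    x = 0; int f = 0; char ch = getchar();
    while (ch < '0' || ch > '9') { f |= ch == '-'; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
    x = f ? -x : x;
}

int fa[M];  // DSU parent array
int n, m;

struct dian {  // edge (u, v) with weight val
    int u, v;
    int val;
    bool operator<(const dian &t) const { return val < t.val; }
} p[M];

// DSU find with path compression.
int find(int a) {
    if (a == fa[a]) return a;
    return fa[a] = find(fa[a]);
}

// ans[w]: number of pairs whose path maximum is <= w; num[r]: size of the block rooted at r.
long long ans[M], num[M];

int main() {
    read(n); read(m);
    for (int i = 1; i < n; i++) {  // `register` dropped: it is no longer valid in C++17
        int a, b, c;
        read(a); read(b); read(c);
        p[i].u = a; p[i].v = b; p[i].val = c;
    }
    sort(p + 1, p + n);  // edges 1..n-1 in increasing order of weight

    long long tmp = 0;  // pairs counted so far
    for (int i = 1; i <= n; i++) fa[i] = i, num[i] = 1;
    for (int i = 1; i < n; i++) {
        int l = find(p[i].u), r = find(p[i].v);
        fa[l] = r;
        long long ff = num[l] * num[r];  // new pairs created by this merge
        ans[p[i].val] = ff + tmp;        // a later edge of equal weight overwrites with a larger total
        tmp += ff;
        num[r] += num[l];
    }
    // A weight with no edge inherits the answer of the next smaller weight
    // (ans[i] > 0 whenever some edge has weight i).
    for (int i = 1; i <= 200000; i++)
        if (!ans[i]) ans[i] = ans[i - 1];

    for (int i = 1; i <= m; i++) {
        int a;
        read(a);
        printf("%lld ", ans[a]);
    }
    return 0;
}
```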
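Approach:
- Start with every vertex as its own block. Add the edges in increasing order of weight, merging the blocks each edge connects into a larger block, and maintain the answer with a disjoint-set union (DSU) as the blocks merge. When an edge of weight $w$ joins blocks of sizes $a$ and $b$, exactly $a \cdot b$ new pairs become connected, and the maximum edge on each of their paths is $w$ (every other edge on those paths was added earlier, so it is no heavier). The answer for a query $q$ is therefore the total number of pairs created by edges of weight at most $q$.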
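The merging idea can be sketched as follows. This is a minimal illustration under stated assumptions, not the author's code. `WeightedEdge`, `DSU` and `countPairs` are hypothetical names. The DSU uses union by size with path halving. Queries are answered by binary searching the sorted edge weights instead of filling a table indexed by weight.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical edge type for this sketch: endpoints u, v (1-indexed) and weight w.
struct WeightedEdge { int u, v, w; };

// Disjoint-set union with union by size (done in countPairs) and path halving.
struct DSU {
    std::vector<int> parent, size;
    explicit DSU(int n) : parent(n + 1), size(n + 1, 1) {
        std::iota(parent.begin(), parent.end(), 0);  // every vertex is its own block
    }
    int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];  // path halving
            x = parent[x];
        }
        return x;
    }
};

std::vector<long long> countPairs(int n, std::vector<WeightedEdge> edges,
                                  const std::vector<int>& queries) {
    assert(static_cast<int>(edges.size()) == n - 1);  // input must be a tree
    std::sort(edges.begin(), edges.end(),
              [](const WeightedEdge& a, const WeightedEdge& b) { return a.w < b.w; });
    DSU dsu(n);
    // prefix[k] = number of good pairs once the k lightest edges have been added.
    std::vector<long long> prefix(edges.size() + 1, 0);
    std::vector<int> weights(edges.size());
    for (std::size_t k = 0; k < edges.size(); ++k) {
        int a = dsu.find(edges[k].u), b = dsu.find(edges[k].v);
        // Merging blocks of sizes |a| and |b| creates |a| * |b| new pairs.
        prefix[k + 1] = prefix[k] + 1LL * dsu.size[a] * dsu.size[b];
        if (dsu.size[a] < dsu.size[b]) std::swap(a, b);  // attach smaller under larger
        dsu.parent[b] = a;
        dsu.size[a] += dsu.size[b];
        weights[k] = edges[k].w;
    }
    std::vector<long long> answers;
    answers.reserve(queries.size());
    for (int q : queries) {
        // The edges with weight <= q are exactly the first k edges in sorted order.
        std::size_t k = std::upper_bound(weights.begin(), weights.end(), q) - weights.begin();
        answers.push_back(prefix[k]);
    }
    return answers;
}
```

The binary search removes the query step's dependence on the bound $w_i \le 2 \cdot 10^5$, at a cost of $O(\log n)$ per query. With this approach, the overall running time is $O((n + m) \log n)$.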