时间限制:1000ms
单点时限:1000ms
内存限制:256MB
描述
现在有一棵有N个带权顶点的树,顶点编号为1,2,...,N。我们定义一条路径的次小(最小)权为它经过的所有顶点(包括起点和终点)中权值次小(最小)顶点的权值。现在给定常数c,你需要求出:存在多少个使得u<v的顶点组(u,v),满足从u到v的最短路的次小权恰为c但最小权不为c。
输入
第一行有两个数N和c。(1<=n<=100000)
第二行N个数,依次表示每个顶点的权值。
接下来N-1行,每行两个数,代表这棵树的一条边所连接的两个顶点的编号。
我们保证输入中的数都在int以内。
输出
一个数,为答案。
样例输入
8 2
2 2 3 3 1 2 3 2
1 2
3 2
3 8
4 2
5 2
5 6
6 7
样例输出
17
Solution
为了方便, 把我们要考虑的树记作$T=(V, E)$, 用$w[u]$表示节点$u$ ($uin V$) 的权值.
先考虑一个简化的问题:
求最小权小于$c$且次小权不小于$c$的路径$(u, v)$的数目.
为了解决这个问题, 我们考虑如下的添边过程:
我们考虑一个动态的图$S(V, E'), E'subseteq E$.
从$S=(V, emptyset)$开始, 先把所有满足$w[u]ge c land w[v] ge c$的边$(u, v)$加到$S$中,
然后考虑满足
[w[u]<c land w[v]ge c lor w[u]ge c land w[v] <c]
的边$(u, v)$, 不失一般性, 不妨设 $w[u]<c, w[v]ge c$.
我们先把$u$固定为$u_0$, 考虑若将所有符合上述条件的边${(u_0, v)}$加到$s$中将能获得多少满足条件的路径.
显然这些满足条件的路径上的最小权就是$w[u_0]$.
(未完待续...)
(无力写了, 先把代码贴上)
UPD
前面写得太罗嗦了, 结果现在自己都看不大懂了. 其实做法一句话就能说清楚:
最小权小于$c$, 次小权不小于$c$的路径数 $-$ 最小权小于$c$, 次小权大于$c$的路径数
Implementation
1 #include <bits/stdc++.h> 2 using namespace std; 3 using LL=long long; 4 const int N{1<<17}; 5 6 int a[N]; 7 8 struct edge{ 9 int u, v; 10 void read(){ 11 cin>>u>>v; 12 } 13 }e[N]; 14 15 struct DSU{ 16 int par[N], size[N]; 17 int n; 18 DSU(int n):n(n){} 19 void init(){ 20 for(int i=1; i<=n; i++){ 21 par[i]=i; 22 size[i]=1; 23 } 24 } 25 int find(int x){ 26 return x==par[x]?x:par[x]=find(par[x]); 27 } 28 void unite(int x, int y){ 29 x=find(x), y=find(y); 30 if(x!=y) par[x]=y, size[y]+=size[x]; 31 } 32 }; 33 34 vector<int> f[N]; 35 36 void prep(DSU &b, int n, int c){ 37 b.init(); 38 for(int i=1; i<=n; i++) f[i].clear(); 39 for(int i=1; i<n; i++){ 40 int u=e[i].u, v=e[i].v; 41 if(a[u]>=c && a[v]>=c){ 42 b.unite(u, v); 43 } 44 } 45 } 46 47 int main(){ 48 int n, c; 49 cin>>n>>c; 50 DSU b(n); 51 52 for(int i=1; i<=n; i++) 53 cin>>a[i]; 54 for(int i=1; i<n; i++) 55 e[i].read(); 56 57 58 LL res=0; 59 60 prep(b, n, c); 61 62 for(int i=1; i<n; i++){ 63 int u=e[i].u, v=e[i].v; 64 if(a[u]<c ^ a[v]<c){ //tricky 65 // cout<<u<<' '<<v<<endl; 66 if(a[v]<c) swap(u, v); 67 int rv=b.find(v); 68 // res+=LL(b.size[u])*LL(b.size[v]); 69 // if(ru!=rv) 70 f[u].push_back(b.size[rv]); 71 } 72 } 73 74 for(int i=1; i<=n; i++){ 75 // if(f[i].size()) cout<<"#"<<i<<endl; 76 LL sum=0, t=0; 77 for(auto &x: f[i]) 78 sum+=x; 79 for(auto &x: f[i]) t+=LL(x)*(sum-x); 80 res+=t>>1; 81 res+=sum; 82 } 83 84 85 prep(b, n, c+1); 86 87 for(int i=1; i<n; i++){ 88 int u=e[i].u, v=e[i].v; 89 if(a[u]<c && a[v]>c || a[u]>c && a[v]<c){ //tricky 90 if(a[v]<c) swap(u, v); 91 int rv=b.find(v); 92 // res+=LL(b.size[u])*LL(b.size[v]); 93 f[u].push_back(b.size[rv]); 94 } 95 } 96 97 for(int i=1; i<=n; i++){ 98 LL sum=0, t=0; 99 for(auto &x: f[i]) 100 sum+=x; 101 for(auto &x: f[i]) t+=LL(x)*(sum-x); 102 res-=t>>1, res-=sum; 103 } 104 105 cout<<res<<endl; 106 }