@description@
九条可怜是一个喜欢规律的女孩子。按照规律,第二题应该是一道和数据结构有关的题。
在一个遥远的国度,有 n 个城市。城市之间有 n - 1 条双向道路,这些道路保证了任何两个城市之间都能直接或者间接地到达。
在上古时代,这 n 个城市之间处于战争状态。在高度闭塞的环境中,每个城市都发展出了自己的语言。而在王国统一之后,语言不通给王国的发展带来了极大的阻碍。为了改善这种情况,国王下令设计了 m 种通用语,并进行了 m 次语言统一工作。在第 i 次统一工作中,一名大臣从城市 si 出发,沿着最短的路径走到了 ti,教会了沿途所有城市(包括 si, ti)使用第 i 个通用语。
一旦有了共通的语言,那么城市之间就可以开展贸易活动了。两个城市 ui, vi 之间可以开展贸易活动当且仅当存在一种通用语 L 满足 ui 到 vi 最短路上的所有城市(包括 ui, vi),都会使用 L。
为了衡量语言统一工作的效果,国王想让你计算有多少对城市 (u, v) (u < v),他们之间可以开展贸易活动。
@solution@
分为 u, v 有祖先关系;u, v 无祖先关系两类统计。
有祖先关系,不妨假设 u 是 v 的祖先。
只需求出路径的某一端在 v 的子树中,向上延伸深度最小为多少。
深度最小就是 lca,自下而上更新即可。
无祖先关系,不妨假设 dfs 序中 u 在 v 前面。
一样的,路径的某一端在 v 的子树,此时路径另一端 dfs 序中需要在 v 前面。
但是不同的路径可能会重复经过某一个点,导致重复统计。
假设所有路径另一端点的点集为 S,我们取 S + {v} 到根的链的并集,然后扣掉 {v} 到根的点数,就可以得到答案。
链并集就是个经典问题:∑端点深度 - ∑dfs序中相邻点lca的深度。然后可以用线段树合并维护链并集。
如果用 O(logn) 的 lca,则总时间复杂度为 O(nlog^2n)。
当然还要在线段树中去掉 dfs 序在 v 后面的点。
@accepted code@
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 100000;
struct edge{
int to; edge *nxt;
}edges[2*MAXN + 5], *adj[MAXN + 5], *ecnt = edges;
void addedge(int u, int v) {
edge *p = (++ecnt);
p->to = v, p->nxt = adj[u], adj[u] = p;
p = (++ecnt);
p->to = u, p->nxt = adj[v], adj[v] = p;
}
int fa[20][MAXN + 5], dep[MAXN + 5];
int dfn[MAXN + 5], tid[MAXN + 5], dcnt;
void dfs1(int x, int f) {
fa[0][x] = f;
for(int i=1;i<20;i++)
fa[i][x] = fa[i-1][fa[i-1][x]];
dep[x] = dep[f] + 1, dfn[++dcnt] = x, tid[x] = dcnt;
for(edge *p=adj[x];p;p=p->nxt)
if( p->to != f ) dfs1(p->to, x);
}
int lca(int u, int v) {
if( dep[u] < dep[v] ) swap(u, v);
for(int i=19;i>=0;i--)
if( dep[fa[i][u]] >= dep[v] )
u = fa[i][u];
if( u == v ) return u;
for(int i=19;i>=0;i--)
if( fa[i][u] != fa[i][v] )
u = fa[i][u], v = fa[i][v];
return fa[0][u];
}
int ch[2][20*MAXN + 5], lm[20*MAXN + 5], rm[20*MAXN + 5], ncnt;
ll sum[20*MAXN + 5]; int rt[20*MAXN + 5];
void pushup(int x) {
lm[x] = (lm[ch[0][x]] != -1 ? lm[ch[0][x]] : lm[ch[1][x]]);
rm[x] = (rm[ch[1][x]] != -1 ? rm[ch[1][x]] : rm[ch[0][x]]);
sum[x] = sum[ch[0][x]] + sum[ch[1][x]];
if( rm[ch[0][x]] != -1 && lm[ch[1][x]] != -1 )
sum[x] -= dep[lca(dfn[rm[ch[0][x]]], dfn[lm[ch[1][x]]])];
}
void update(int &x, int l, int r, int p, int d) {
if( !x ) x = (++ncnt), lm[x] = -1, rm[x] = -1;
if( l == r ) {
if( d == 1 ) {
if( lm[x] == -1 )
lm[x] = rm[x] = l, sum[x] = dep[dfn[l]];
} else if( d == -1 ) {
if( lm[x] != -1 )
lm[x] = rm[x] = -1, sum[x] = 0;
}
return ;
}
int m = (l + r) >> 1;
if( p <= m ) update(ch[0][x], l, m, p, d);
else update(ch[1][x], m + 1, r, p, d);
pushup(x);
}
int merge(int x, int y, int l, int r) {
if( !x || !y ) return x + y;
if( l == r ) {
if( lm[y] != -1 )
lm[x] = rm[x] = l, sum[x] = dep[dfn[l]];
return x;
}
int m = (l + r) >> 1;
ch[0][x] = merge(ch[0][x], ch[0][y], l, m);
ch[1][x] = merge(ch[1][x], ch[1][y], m + 1, r);
pushup(x); return x;
}
int n, m;
vector<int>v[MAXN + 5]; int mnd[MAXN + 5]; ll ans;
void dfs2(int x, int f) {
rt[x] = 0;
for(int i=0;i<v[x].size();i++)
update(rt[x], 1, n, tid[v[x][i]], 1);
for(edge *p=adj[x];p;p=p->nxt) {
if( p->to == f ) continue;
dfs2(p->to, x), mnd[x] = min(mnd[x], mnd[p->to]);
rt[x] = merge(rt[x], rt[p->to], 1, n);
}
ans += (dep[x] - mnd[x]);
while( rm[rt[x]] >= tid[x] )
update(rt[x], 1, n, rm[rt[x]], -1);
ans += sum[rt[x]];
if( rm[rt[x]] != -1 ) ans -= dep[lca(x, dfn[rm[rt[x]]])];
}
int main() {
// freopen("language.in", "r", stdin);
// freopen("language.out", "w", stdout);
scanf("%d%d", &n, &m);
for(int i=1,x,y;i<n;i++)
scanf("%d%d", &x, &y), addedge(x, y);
dfs1(1, 0);
for(int i=1;i<=n;i++) mnd[i] = dep[i];
for(int i=1,s,t,l;i<=m;i++) {
scanf("%d%d", &s, &t), l = lca(s, t);
mnd[s] = min(mnd[s], dep[l]), mnd[t] = min(mnd[t], dep[l]);
if( tid[s] > tid[t] ) swap(s, t); v[t].push_back(s);
}
lm[0] = rm[0] = -1, dfs2(1, 0), printf("%lld
", ans);
}
@details@
我竟然做出来一道ZJOI题?可能也只有这一道了吧
数组太多了可能会搞混,一定要区分清楚。
理论上最优复杂度是写 O(1) 的 lca,好在出题人没有卡(