Autostrady
https://szkopul.edu.pl/problemset/problem/f2dSBM7JteWHqtmVejMWe1bW/site/?key=statement
题意:
给定一棵 n 个点的树,除了 n-1 条树边以外,还有 m 条非树边。每次询问给出两个点 u、v,求 u 到 v 之间满足以下条件的路径条数。
- 不能经过树上 u 到 v 的简单路径上的边。
- 只能走一条非树边。
分析:
RMQ求LCA + 线段树合并。
问题转化为:有多少条非树边的一个端点在 u 的子树内,另一个端点在 v 的子树内(程序输出时答案再加 1)。
每个询问只挂在深度较大的点上,并记录深度较小的那个点。对每个节点建立一棵以 DFS 序为下标的权值线段树,存放以该子树内的点为一端的非树边的另一端点;dfs 时从叶子节点往上合并,处理到某个节点时,对挂在它上面的每个询问,查询另一个点的子树所对应的 DFS 序区间内的端点个数。
如果询问的两个点中一个是另一个的祖先(设 v 是 u 的祖先),要特判:找到 v 的儿子中同时是 u 祖先的那个点 t,答案为 u 的线段树中的端点总数减去落在 t 子树 DFS 序区间内的端点个数。
代码:
// Autostrady (szkopul): a tree on n vertices plus m extra (non-tree) edges.
// For every query (u, v) the program counts the non-tree edges that have one
// endpoint in subtree(u) (u = the deeper query vertex) and the other endpoint
// on v's side once the tree path u..v is removed, and prints count + 1 for
// each query, separated by spaces.
//
// Approach: mergeable segment trees indexed by DFS entry time. Processing
// vertices bottom-up, the tree owned by vertex u stores, for each non-tree
// edge with an endpoint in subtree(u), the entry time of its other endpoint.
// Each query is attached to its deeper vertex u, then:
//   * v not an ancestor of u: count endpoints inside subtree(v);
//   * v a proper ancestor of u: count all endpoints except those inside
//     subtree(t), where t is the child of v on the path to u;
//   * u == v: the count stays 0.
// Ancestor tests use DFS entry-time intervals (equivalent to checking
// Lca(u, v) == v, so no sparse table is needed), and t is found by binary
// search over v's children, so a high-degree v no longer costs O(deg(v)) per
// query.
//
// All traversals are iterative, so a path-shaped tree (depth ~ n) cannot
// overflow the call stack. All storage grows on demand, so neither the number
// of queries nor the number of segment-tree nodes has a fixed cap.
// Written in C++03-compatible style to match the original solution.
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

// Reads one decimal integer from stdin; a '-' seen before the digits makes it
// negative. Returns 0 if EOF is reached before any digit (instead of looping
// forever). The character is kept as int so isdigit() never receives a
// negative value other than EOF.
int read() {
    int value = 0;
    int sign = 1;
    int ch = std::getchar();
    for (; !std::isdigit(ch); ch = std::getchar()) {
        if (ch == EOF) return 0;
        if (ch == '-') sign = -1;
    }
    for (; std::isdigit(ch); ch = std::getchar()) value = value * 10 + (ch - '0');
    return value * sign;
}

// Pool of dynamically created segment-tree nodes over positions [1, width].
// Index 0 is the shared null node (count 0, no children). Nodes released by
// merge() are recycled by make().
class SegmentPool {
public:
    explicit SegmentPool(int width)
        : width_(width), cnt_(1, 0), lc_(1, 0), rc_(1, 0) {}

    // Returns an empty node (count 0, no children).
    int make() {
        if (!free_.empty()) {
            const int k = free_.back();
            free_.pop_back();
            cnt_[k] = lc_[k] = rc_[k] = 0;
            return k;
        }
        cnt_.push_back(0);
        lc_.push_back(0);
        rc_.push_back(0);
        return static_cast<int>(cnt_.size()) - 1;
    }

    // Adds one occurrence of pos to the tree rooted at `root`, allocating the
    // root if it is 0. `root` must not refer into this pool's own storage.
    // Child links are re-read after every make(), because make() may
    // reallocate the node arrays.
    void insert(int& root, int pos) {
        if (root == 0) root = make();
        int node = root;
        int lo = 1;
        int hi = width_;
        ++cnt_[node];
        while (lo < hi) {
            const int mid = lo + (hi - lo) / 2;
            const bool go_left = pos <= mid;
            if (go_left) hi = mid; else lo = mid + 1;
            int child = go_left ? lc_[node] : rc_[node];
            if (child == 0) {
                child = make();
                if (go_left) lc_[node] = child; else rc_[node] = child;
            }
            node = child;
            ++cnt_[node];
        }
    }

    // Merges tree y into tree x and returns the merged root. Nodes of y are
    // either adopted by x or released for reuse, so y must not be used again.
    int merge(int x, int y) {
        if (x == 0 || y == 0) return x + y;
        const int left = merge(lc_[x], lc_[y]);
        const int right = merge(rc_[x], rc_[y]);
        lc_[x] = left;
        rc_[x] = right;
        // Add counts directly: leaves have no children to re-sum from.
        cnt_[x] += cnt_[y];
        free_.push_back(y);
        return x;
    }

    // Total number of stored positions in the tree rooted at `root`.
    int total(int root) const { return cnt_[root]; }

    // Number of stored positions p with ql <= p <= qr.
    int count(int root, int ql, int qr) const {
        return count_in(root, 1, width_, ql, qr);
    }

private:
    int count_in(int node, int lo, int hi, int ql, int qr) const {
        if (node == 0 || qr < lo || hi < ql) return 0;
        if (ql <= lo && hi <= qr) return cnt_[node];
        const int mid = lo + (hi - lo) / 2;
        return count_in(lc_[node], lo, mid, ql, qr) +
               count_in(rc_[node], mid + 1, hi, ql, qr);
    }

    int width_;
    std::vector<int> cnt_;
    std::vector<int> lc_;
    std::vector<int> rc_;
    std::vector<int> free_;
};

// Per-vertex data from one DFS over vertices 1..n.
struct RootedTree {
    std::vector<int> parent;  // 0 for the root
    std::vector<int> depth;   // the root has depth 1
    std::vector<int> tin;     // 1-based preorder (DFS entry) index
    std::vector<int> sub;     // subtree size: subtree(v) = [tin[v], tin[v] + sub[v] - 1]
    std::vector<std::vector<int> > children;  // in visit order, i.e. increasing tin

    // True if a is an ancestor of b (a == b counts).
    bool is_ancestor(int a, int b) const {
        return tin[a] <= tin[b] && tin[b] < tin[a] + sub[a];
    }

    // Child of a whose subtree contains b. Requires a to be a proper ancestor
    // of b, so children[a] is non-empty and its first entry qualifies.
    int child_towards(int a, int b) const {
        const std::vector<int>& ch = children[a];
        int lo = 0;
        int hi = static_cast<int>(ch.size()) - 1;
        while (lo < hi) {  // find the last child with tin <= tin[b]
            const int mid = lo + (hi - lo + 1) / 2;
            if (tin[ch[mid]] <= tin[b]) lo = mid; else hi = mid - 1;
        }
        return ch[lo];
    }
};

// Iterative DFS from `root` that visits neighbours in adjacency-list order
// (the same preorder a recursive DFS produces) and fills every RootedTree field.
RootedTree build_tree(const std::vector<std::vector<int> >& adj, int n, int root) {
    RootedTree t;
    t.parent.assign(n + 1, 0);
    t.depth.assign(n + 1, 0);
    t.tin.assign(n + 1, 0);
    t.sub.assign(n + 1, 1);
    t.children.assign(n + 1, std::vector<int>());
    std::vector<std::size_t> next_edge(n + 1, 0);
    std::vector<int> stack(1, root);
    int timer = 0;
    t.depth[root] = 1;
    t.tin[root] = ++timer;
    while (!stack.empty()) {
        const int u = stack.back();
        if (next_edge[u] < adj[u].size()) {
            const int v = adj[u][next_edge[u]++];
            if (v == t.parent[u]) continue;
            t.parent[v] = u;
            t.depth[v] = t.depth[u] + 1;
            t.tin[v] = ++timer;
            t.children[u].push_back(v);
            stack.push_back(v);
        } else {
            stack.pop_back();
            if (t.parent[u] != 0) t.sub[t.parent[u]] += t.sub[u];
        }
    }
    return t;
}

// Walks the tree bottom-up (iterative post-order), building each vertex's
// segment tree (own extra-edge endpoints plus merged children), and answers
// the queries attached to a vertex once its subtree is complete.
// pending[u] holds (other vertex, query index) pairs; returns raw counts.
std::vector<int> answer_queries(
        const RootedTree& tree,
        const std::vector<std::vector<int> >& ext,
        const std::vector<std::vector<std::pair<int, int> > >& pending,
        int n, int q) {
    std::vector<int> ans(q > 0 ? q : 0, 0);
    SegmentPool pool(n);
    std::vector<int> roots(n + 1, 0);
    std::vector<std::size_t> next_child(n + 1, 0);
    std::vector<int> stack(1, 1);
    roots[1] = pool.make();
    while (!stack.empty()) {
        const int u = stack.back();
        if (next_child[u] < tree.children[u].size()) {
            const int c = tree.children[u][next_child[u]++];
            roots[c] = pool.make();
            stack.push_back(c);
            continue;
        }
        stack.pop_back();

        // Other endpoints (by entry time) of extra edges incident to u.
        for (std::size_t k = 0; k < ext[u].size(); ++k)
            pool.insert(roots[u], tree.tin[ext[u][k]]);

        for (std::size_t k = 0; k < pending[u].size(); ++k) {
            const int v = pending[u][k].first;
            const int idx = pending[u][k].second;
            if (!tree.is_ancestor(v, u)) {
                ans[idx] = pool.count(roots[u], tree.tin[v],
                                      tree.tin[v] + tree.sub[v] - 1);
            } else if (v != u) {
                // Exclude endpoints inside the subtree of v's child towards u.
                const int t = tree.child_towards(v, u);
                ans[idx] = pool.total(roots[u]) -
                           pool.count(roots[u], tree.tin[t],
                                      tree.tin[t] + tree.sub[t] - 1);
            }
            // u == v: no child of v lies towards u, so the count stays 0.
        }

        // u's subtree is finished: fold its tree into the parent's tree.
        const int p = tree.parent[u];
        if (p != 0) roots[p] = pool.merge(roots[p], roots[u]);
    }
    return ans;
}

}  // namespace

int main() {
    const int n = read();
    if (n < 1) return 0;

    std::vector<std::vector<int> > adj(n + 1);
    std::vector<std::vector<int> > ext(n + 1);
    for (int i = 1; i < n; ++i) {
        const int u = read();
        const int v = read();
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    const int m = read();
    for (int i = 0; i < m; ++i) {
        const int u = read();
        const int v = read();
        ext[u].push_back(v);
        ext[v].push_back(u);
    }

    const RootedTree tree = build_tree(adj, n, 1);

    // Attach each query to its deeper endpoint (ties go to the second vertex).
    const int q = read();
    std::vector<std::vector<std::pair<int, int> > > pending(n + 1);
    for (int i = 0; i < q; ++i) {
        const int u = read();
        const int v = read();
        if (tree.depth[u] > tree.depth[v]) pending[u].push_back(std::make_pair(v, i));
        else pending[v].push_back(std::make_pair(u, i));
    }

    const std::vector<int> ans = answer_queries(tree, ext, pending, n, q);
    for (int i = 0; i < q; ++i) std::printf("%d ", ans[i] + 1);
    return 0;
}