有N个点的树,现在我们有M次询问,询问(u、v),只在u、v子树内出现的颜色的个数(u子树 并 v子树)。
首先,可以将问题拆开来讨论,如果对于一个点的时候是怎样的呢?
先求一个点时候的答案:
那么,实际上就是求“颜色总数-不在子树内出现的颜色总数”。那么,实际上,我们可以将点数扩展成2N这么多个,然后查询出现在dfn[u] + siz[u] ~ dfn[u] + N - 1内的颜色的个数。当然,不嫌麻烦的话,就不需要这么做,直接写树剖,然后找到每个颜色的总的父亲节点,给1~point这条链全部“+1”,然后之后单点查询即可。——于是,我们就可以求得每个点的自身子树所产生的答案了。
然后考虑两个点的时候,当且仅当我们对一个颜色构建虚树,虚树有虚根(虚根颜色不为该颜色),且虚根有且仅有两个虚子节点的时候,才需要考虑这种颜色的贡献。
于是,求两个点时候的答案:
我们不妨对虚子节点进行配对,譬如说“u指向v”这样的,于是,对于查询的x、y,我们实际上就是想知道x的子树内,有多少是指向y子树内的。那么,对于x的子树区间为dfn[x] ~ dfn[x] + siz[x] - 1,同理,y的子树。那么,比较显然的,这不就是用dfs序来构造一个主席树,然后在主席树上进行差分,找到x子树对于y子树的贡献吗?
所以,我们可以利用虚树来找到这样的u、v对,然后我们让dfs序小的指向dfs序大的,最后,我们对它的dfs序上建立可持久化线段树,然后对于查询的x、y对,我们用dfs序小的作为根的差分,去查dfs序大的部分。——于是,利用这些,我们就可以知道这样的关系对对每个答案产生的贡献了。
1 #include <iostream> 2 #include <cstdio> 3 #include <cmath> 4 #include <string> 5 #include <cstring> 6 #include <algorithm> 7 #include <limits> 8 #include <vector> 9 #include <stack> 10 #include <queue> 11 #include <set> 12 #include <map> 13 #include <bitset> 14 #include <unordered_map> 15 #include <unordered_set> 16 #define lowbit(x) ( x&(-x) ) 17 #define pi 3.141592653589793 18 #define e 2.718281828459045 19 #define INF 0x3f3f3f3f 20 #define HalF (l + r)>>1 21 #define lsn rt<<1 22 #define rsn rt<<1|1 23 #define Lson lsn, l, mid 24 #define Rson rsn, mid+1, r 25 #define QL Lson, ql, qr 26 #define QR Rson, ql, qr 27 #define myself rt, l, r 28 #define pii pair<int, int> 29 #define MP(a, b) make_pair(a, b) 30 using namespace std; 31 typedef unsigned long long ull; 32 typedef unsigned int uit; 33 typedef long long ll; 34 const int maxN = 1e5 + 7; 35 int N, M, head[maxN], cnt, col[maxN], lsan[maxN], _UP; 36 vector<int> vt[maxN]; 37 struct Eddge 38 { 39 int nex, to; 40 Eddge(int a=-1, int b=0):nex(a), to(b) {} 41 } edge[maxN << 1]; 42 inline void addEddge(int u, int v) 43 { 44 edge[cnt] = Eddge(head[u], v); 45 head[u] = cnt ++; 46 } 47 inline void _add(int u, int v) { addEddge(u, v); addEddge(v, u); } 48 inline void init() 49 { 50 cnt = 0; 51 for(int i = 1; i <= N; i ++) head[i] = -1; 52 } 53 int deep[maxN], fa[maxN][20], LOG_2[maxN], siz[maxN], dfn[maxN], idx, rid[maxN << 1]; 54 bool cmp(int a, int b) { return dfn[a] < dfn[b]; } 55 void dfs(int u, int father) 56 { 57 siz[u] = 1; 58 fa[u][0] = father; 59 deep[u] = deep[father] + 1; 60 dfn[u] = ++idx; 61 rid[idx] = u; 62 for(int i = 0; i < 18; i ++) fa[u][i + 1] = fa[fa[u][i]][i]; 63 for(int i = head[u], v; ~i; i = edge[i].nex) 64 { 65 v = edge[i].to; 66 if(v == father) continue; 67 dfs(v, u); 68 siz[u] += siz[v]; 69 } 70 } 71 int lca(int u, int v) 72 { 73 if(deep[u] < deep[v]) swap(u, v); 74 int det = deep[u] - deep[v]; 75 for(int i = LOG_2[det]; i >= 0; i --) 76 { 77 if((det >> i) & 1) u = fa[u][i]; 78 } 79 if(u == v) return u; 80 for(int i = LOG_2[deep[v]]; i >= 0; i --) 81 { 82 if(fa[u][i] ^ fa[v][i]) 83 { 84 u = fa[u][i]; 85 v = fa[v][i]; 86 } 87 } 88 return fa[u][0]; 89 } 90 int stk[maxN], top; 91 vector<int> son; 92 void Insert(int u) 93 { 94 if(top <= 1) 95 { 96 stk[++ top] = u; 97 return; 98 } 99 int p = lca(u, stk[top]); 100 while(top >= 2 && dfn[p] <= dfn[stk[top - 1]]) 101 { 102 if(top == 2) son.push_back(stk[top]); 103 top --; 104 } 105 if(stk[top] ^ p) 106 { 107 stk[top] = p; 108 } 109 stk[++ top] = u; 110 } 111 vector<int> nex[maxN]; 112 void solve(int op) 113 { 114 sort(vt[op].begin(), vt[op].end(), cmp); 115 int root = vt[op][0]; 116 for(int u : vt[op]) root = lca(root, u); 117 top = 0; son.clear(); 118 Insert(root); 119 for(int u : vt[op]) 120 { 121 if(u == root) 122 { 123 top = 0; 124 son.clear(); 125 break; 126 } 127 Insert(u); 128 } 129 if(top >= 2) son.push_back(stk[2]); 130 if(son.size() == 2) 131 { 132 int x = son[0], y = son[1]; 133 if(dfn[x] > dfn[y]) swap(x, y); 134 nex[dfn[x]].push_back(dfn[y]); 135 } 136 } 137 int t[maxN << 1]; 138 void add(int x, int v) { while(x <= (N << 1)) { t[x] += v; x += lowbit(x); } } 139 int sum(int x) { int res = 0; while(x) { res += t[x]; x -= lowbit(x); } return res; } 140 vector<pii> ques[maxN << 1]; 141 int las_col[maxN] = {0}; 142 int ans[maxN]; 143 namespace Segement 144 { 145 const int maxP = maxN * 30; 146 int tree[maxP], lc[maxP], rc[maxP]; 147 int root[maxN], tot; 148 void build(int &rt, int old, int l, int r, int qx) 149 { 150 rt = ++ tot; 151 lc[rt] = lc[old]; rc[rt] = rc[old]; tree[rt] = tree[old] + 1; 152 if(l == r) return; 153 int mid = HalF; 154 if(qx <= mid) build(lc[rt], lc[old], l, mid, qx); 155 else build(rc[rt], rc[old], mid + 1, r, qx); 156 } 157 int query(int rl, int rr, int l, int r, int ql, int qr) 158 { 159 if(ql <= l && qr >= r) return tree[rr] - tree[rl]; 160 int mid = HalF; 161 if(qr <= mid) return query(lc[rl], lc[rr], l, mid, ql, qr); 162 else if(ql > mid) return query(rc[rl], rc[rr], mid + 1, r, ql, qr); 163 else return query(lc[rl], lc[rr], l, mid, ql, qr) + query(rc[rl], rc[rr], mid + 1, r, ql, qr); 164 } 165 } 166 using namespace Segement; 167 struct Question 168 { 169 int u, v, id; 170 Question(int a=0, int b=0, int c=0):u(a), v(b), id(c) {} 171 }; 172 vector<Question> qt; 173 int main() 174 { 175 for(int i = 2; i < maxN; i ++) LOG_2[i] = LOG_2[i >> 1] + 1; 176 scanf("%d%d", &N, &M); 177 init(); 178 for(int i = 1; i <= N; i ++) { scanf("%d", &col[i]); lsan[i] = col[i]; } 179 sort(lsan + 1, lsan + N + 1); 180 _UP = (int)(unique(lsan + 1, lsan + N + 1) - lsan - 1); 181 for(int i = 1; i <= N; i ++) 182 { 183 col[i] = (int)(lower_bound(lsan + 1, lsan + _UP + 1, col[i]) - lsan); 184 vt[col[i]].push_back(i); 185 } 186 for(int i = 1, u, v; i < N; i ++) 187 { 188 scanf("%d%d", &u, &v); 189 _add(u, v); 190 } 191 dfs(1, 0); 192 for(int i = 1; i <= N; i ++) rid[N + i] = rid[i]; 193 for(int i = 1; i <= _UP; i ++) solve(i); 194 for(int i = 1, x, y, p, l, r; i <= M; i ++) 195 { 196 scanf("%d%d", &x, &y); 197 if(dfn[x] > dfn[y]) swap(x, y); 198 p = lca(x, y); 199 if(p == x) 200 { 201 ans[i] = _UP; 202 l = dfn[p] + siz[p]; r = N + dfn[p] - 1; 203 ques[r].push_back(MP(l, i)); 204 } 205 else 206 { 207 ans[i] = _UP << 1; 208 l = dfn[x] + siz[x]; r = N + dfn[x] - 1; 209 ques[r].push_back(MP(l, i)); 210 l = dfn[y] + siz[y]; r = N + dfn[y] - 1; 211 ques[r].push_back(MP(l, i)); 212 qt.push_back(Question(x, y, i)); 213 } 214 } 215 for(int i = 1, u, c, id, lx; i <= (N << 1); i ++) 216 { 217 u = rid[i]; 218 c = col[u]; 219 add(i, 1); 220 if(las_col[c]) add(las_col[c], -1); 221 las_col[c] = i; 222 for(pii it : ques[i]) 223 { 224 id = it.second; 225 lx = it.first; 226 ans[id] -= sum(i) - sum(lx - 1); 227 } 228 } 229 for(int i = 1; i <= N; i ++) 230 { 231 root[i] = root[i - 1]; 232 for(int j : nex[i]) 233 { 234 build(root[i], root[i], 1, N, j); 235 } 236 } 237 for(Question it : qt) 238 { 239 int u = it.u, v = it.v, id = it.id; 240 int l = dfn[u], r = dfn[u] + siz[u] - 1, ql = dfn[v], qr = dfn[v] + siz[v] - 1; 241 ans[id] += query(root[l - 1], root[r], 1, N, ql, qr); 242 } 243 for(int i = 1; i <= M; i ++) printf("%d ", ans[i]); 244 return 0; 245 }