题目链接:https://codeforces.com/contest/1467/problem/E
对于每一对同色点对,点对内部的点是合法点,外部点是不合法点,将不合法点标记,最后未被标记的即为合法点
1.如果当前点子树内有同色点,那么当前点的子树外所有点不合法
2.如果当前点子树外有同色点,那么当前点的子树内所有点不合法
打标记使用 (dfs) 序 + 前缀和差分即可
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 200010;
int n, m, q;
int a[maxn], b[maxn], c[maxn];
int h[maxn], cnt = 0;
struct Node{
int to, next;
}e[maxn << 1];
void add(int u, int v){
e[++cnt].to = v;
e[cnt].next = h[u];
h[u] = cnt;
}
int tot, dep[maxn], dfn[maxn], sz[maxn], cc[maxn], sum[maxn];
void dfs(int u, int par){
dep[u] = dep[par] + 1;
sz[u] = 1;
dfn[u] = ++tot;
int pre = cc[a[u]]++;
for(int i = h[u] ; i != -1 ; i = e[i].next){
int v = e[i].to;
if(v == par) continue;
int now = cc[a[u]];
dfs(v, u);
sz[u] += sz[v];
if(now != cc[a[u]]){
++sum[1];
--sum[dfn[v]];
++sum[dfn[v] + sz[v]];
}
}
if(c[a[u]] != cc[a[u]] - pre){
++sum[dfn[u]];
--sum[dfn[u] + sz[u]];
}
}
ll read(){ ll s = 0, f = 1; char ch = getchar(); while(ch < '0' || ch > '9'){ if(ch == '-') f = -1; ch = getchar(); } while(ch >= '0' && ch <= '9'){ s = s * 10 + ch - '0'; ch = getchar(); } return s * f; }
int main(){
memset(h, -1, sizeof(h));
n = read();
for(int i = 1 ; i <= n ; ++i) a[i] = read(), b[i] = a[i];
int u, v;
for(int i = 1 ; i < n ; ++i){
u = read(), v = read();
add(u, v), add(v, u);
}
sort(b + 1, b + 1 + n);
q = unique(b + 1, b + 1 + n) - b - 1;
for(int i = 1 ; i <= n ; ++i){
a[i] = lower_bound(b + 1, b + 1 + n, a[i]) - b;
}
for(int i = 1 ; i <= n ; ++i) ++c[a[i]];
dfs(1, 0);
for(int i = 1 ; i <= n ; ++i) sum[i] += sum[i - 1];
int ans = 0;
for(int i = 1 ; i <= n ; ++i) if(!sum[i]) ++ans;
printf("%d
", ans);
return 0;
}