把问题转化
-
选取一个子树包含所有颜色,答案为子树外到子树根最长链的长度
-
选取一个子树使得原树去掉该子树包含所有颜色,答案为子树内到子树根最长链的长度
那么以上情况取最大值即为答案。把问题转化成了求子树的颜色种类,是典型的DSU on Tree.
分开来做即可。
但仍需常数优化,预处理从下向上的最长链和次长链,两种情况一起做。
#include <algorithm>
#include <cstdio>
using namespace std;
const int maxn = 1000000, maxm = 100000;
class Tree {
public:
int n, m, col[maxn + 1], ind[maxn + 1], to[2 * maxn - 1], link[2 * maxn - 1],
son[maxn + 1], size[maxn + 1], len[maxn + 1], ans, f[maxm + 1],
g[maxm + 1], ver[maxm + 1], ml[maxn + 1], sl[maxn + 1];
void addSide(int u, int v) {
m++;
to[m] = v;
link[m] = ind[u];
ind[u] = m;
}
void init(int o, int fa) {
len[o] = 1, size[o] = 1, son[o] = 0, ml[o] = 0, sl[o] = 0;
for (int i = ind[o]; i; i = link[i]) {
if (to[i] != fa) {
init(to[i], o);
len[o] = max(len[o], len[to[i]] + 1);
if (len[ml[o]] <= len[to[i]]) {
sl[o] = ml[o], ml[o] = to[i];
} else if (len[sl[o]] <= len[to[i]]) {
sl[o] = to[i];
}
size[o] += size[to[i]];
if (size[son[o]] < size[to[i]]) {
son[o] = to[i];
}
}
}
}
void solve(int o, int fa, int up, bool keep, bool makeAns) {
int a = ml[o], b = sl[o];
for (int i = ind[o]; makeAns && i; i = link[i]) {
if (to[i] != fa && to[i] != son[o]) {
solve(to[i], o, max(up, to[i] == a ? len[b] : len[a]) + 1, false, true);
}
}
if (son[o]) {
solve(son[o], o, max(up, son[o] == a ? len[b] : len[a]) + 1, true, makeAns);
}
for (int i = ind[o]; i; i = link[i]) {
if (to[i] != fa && to[i] != son[o]) {
solve(to[i], o, 0, true, false);
}
}
if (ver[col[o]] != ver[0]) {
ver[col[o]] = ver[0];
g[col[o]] = 0;
}
if (!g[col[o]]) {
f[0]++;
}
g[col[o]]++;
if (g[col[o]] == f[col[o]]) {
g[0]++;
}
if (makeAns) {
if (f[0] == col[0]) {
ans = max(ans, up + 1);
}
if (!g[0]) {
ans = max(ans, len[o] + 1);
}
}
if (!keep) {
ver[0]++, f[0] = g[0] = 0;
}
}
};
int read() {
int ret = 0;
char c;
for (c = getchar(); c < '0' || c > '9'; c = getchar());
for (; c >= '0' && c <= '9'; c = getchar()) {
ret = ret * 10 + c - '0';
}
return ret;
}
int main() {
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
static Tree tree;
int n = read(), m = read();
tree.n = n;
tree.col[0] = m;
for (int i = 1; i <= n; i++) {
tree.col[i] = read();
tree.f[tree.col[i]]++;
}
for (int i = 1; i < n; i++) {
int u = read(), v = read();
tree.addSide(u, v);
tree.addSide(v, u);
}
tree.g[0] = m;
tree.init(1, 0);
tree.solve(1, 0, 0, true, true);
printf("%d
", tree.ans);
fclose(stdin);
fclose(stdout);
return 0;
}