好像主流做法是点分治,然后我在考场上写了一个时间和空间常数都很大的 tarjan + 倍增优化建图的垃圾做法。
下面介绍我的垃圾做法:
首先考虑对于两个颜色 (x) , (y) ,若选了颜色 (x) 就必须选 (y) ,则在图中连一条 (x
ightarrow y) 的边。那么可以对这个图进行 tarjan 缩点求出 scc,然后答案就是没有出边的 scc 中包含点数最少的那个。 (证明略)
然后考虑如何 (n^2) 地建图:
显然,若一个点 (x) 在颜色相同的两个点 (y,z) 之间的最短路上,则需要有一条边 (col_y
ightarrow col_x),然后考虑设所有颜色为 (c) 的点的集合为 (S_c) , (S_c) 中所有点的 LCA 是 (l) ,那么 (forall x in S_c), (y) 在 (l) 和 (x) 之间的路径上,有一条边 (c
ightarrow col_y)。
那么对于一个点 (x) ,他连出去的边分布在 (|S_x|) 条链上,只要对它进行倍增优化建图即可在 (O(nlog n)) 的复杂度内解决这个问题。
(这个做法真的又慢又费空间,虽然他时间和空间的复杂度都是 (O(nlog n)) 的,但它在考场上跑了 2s,空间用了 400+MB,但也有可能是 STL 过度使用的原因)
Code:
#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define mp make_pair
#define Fast_IO ios::sync_with_stdio(false);
#define DEBUG fprintf(stderr,"Running on Line %d in Function %s
",__LINE__,__FUNCTION__)
#define fir first
#define sec second
#define mod 998244353
#define INF 0x3fffffff
#define ll long long
inline int read()
{
char ch=getchar(); int nega=1; while(!isdigit(ch)) {if(ch=='-') nega=-1; ch=getchar();}
int ans=0; while(isdigit(ch)) {ans=ans*10+ch-48;ch=getchar();}
if(nega==-1) return -ans;
return ans;
}
typedef pair<int,int> pii;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int sub(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1LL*x*y%mod;}
const int N=200005*20;
const int M=200005;
vector<int> G[M];
int col[M],n,m,lg[N];
vector<int> v[M];
int l[M];
int siz[M],dep[M],son[M],top[M],fa[M],f[M][20];
int cnt;
int t[M][20];
void dfs1(int u,int Fa)
{
siz[u]=1; fa[u]=Fa; dep[u]=dep[Fa]+1; f[u][0]=Fa;
for(int v:G[u])
{
if(v==Fa) continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
}
void dfs2(int u,int topf)
{
top[u]=topf;
if(!son[u]) return ;
dfs2(son[u],topf);
for(int v:G[u])
{
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
int lca(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]>dep[top[v]]) u=fa[top[u]];
else v=fa[top[v]];
}
return dep[u]>dep[v]?v:u;
}
vector<int> H[N];
int dgr[N],dfn[N],low[N],vis[N],_cnt,bel[N],_sum,dat[N];
stack<int> s;
void tarjan(int u)
{
if(!u) return ;
low[u]=dfn[u]=++_cnt;
vis[u]=1; s.push(u);
for(int v:H[u])
{
if(v==u) continue;
if(!dfn[v])
{
tarjan(v);
low[u]=min(low[u],low[v]);
}
else if(vis[v]) low[u]=min(low[u],dfn[v]);
}
if(low[u]==dfn[u])
{
_sum++;
while(1)
{
int x=s.top(); s.pop();
bel[x]=_sum;
if(x<=m) dat[_sum]++;
vis[x]=0;
if(x==u) break;
}
}
}
int jump(int u,int dep)
{
for(int i=18;i>=0;i--)
{
if((1<<i)<=dep)
{
dep-=1<<i;
u=f[u][i];
}
}
return u;
}
void addedge(int fr,int u,int v)
{
int x=lg[dep[v]-dep[u]+1];
H[fr].pb(t[v][x]);
int qaq=jump(v,(dep[v]-dep[u]+1)-(1<<x));
H[fr].pb(t[qaq][x]);
}
signed main()
{
lg[1]=0; for(int i=2;i<N;i++) lg[i]=lg[i/2]+1;
cin>>n>>m;
for(int i=1;i<n;i++)
{
int u=read(),v=read();
G[u].pb(v),G[v].pb(u);
}
for(int i=1;i<=n;i++)
{
col[i]=read();
v[col[i]].pb(i);
}
dfs1(1,0),dfs2(1,1);
for(int i=1;i<=m;i++)
{
l[i]=v[i][0];
for(int j=1;j<(int)v[i].size();j++) l[i]=lca(l[i],v[i][j]);
}
for(int i=1;i<=18;i++)
{
for(int j=1;j<=n;j++) f[j][i]=f[f[j][i-1]][i-1];
}
for(int i=1;i<=n;i++) t[i][0]=col[i];
cnt=m;
for(int i=1;i<=18;i++)
{
for(int j=1;j<=n;j++)
{
t[j][i]=++cnt;
H[t[j][i]].pb(t[j][i-1]);
H[t[j][i]].pb(t[f[j][i-1]][i-1]);
}
}
for(int i=1;i<=n;i++)
{
addedge(col[i],l[col[i]],i);
}
for(int i=1;i<=cnt;i++) if(!dfn[i]) tarjan(i);
for(int i=1;i<=cnt;i++)
{
for(int v:H[i])
{
if(bel[i]!=bel[v]) dgr[bel[i]]++;
}
}
int ans=INF;
for(int i=1;i<=_sum;i++) if(!dgr[i]&&dat[i]>0) ans=min(ans,dat[i]);
cout<<ans-1<<endl;
return 0;
}