SPOJ10707 COT2 Count on a tree II
Solution
我会强制在线版本! Solution戳这里
代码实现
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<iostream>
using namespace std;
#define ll long long
#define re register
#define file(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout)
inline int gi()
{
int f=1,sum=0;char ch=getchar();
while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
return f*sum;
}
const int N=60010;
int Bl[N],B,P[N],ans[310][N],a[N],b[N],bl[N],num,p[N][310],Anum,rt[310],F[N];
struct array
{
int num[210];
int operator[](int x){return p[num[Bl[x]]][P[x]];};
void insert(const array &pre,int x,int dep)
{
int block=Bl[x],t=P[x];
memcpy(num,pre.num,sizeof(num));
memcpy(p[++Anum],p[num[block]],sizeof(p[0]));
p[Anum][t]=dep;num[block]=Anum;
}
}s[N];
int to[N<<1],nxt[N<<1],front[N],cnt,dep[N],f[N][22],st[N],sta,kind;
inline void Add(int u,int v)
{
to[++cnt]=v;nxt[cnt]=front[u];front[u]=cnt;
}
inline int dfs(int u,int fa)
{
dep[u]=dep[fa]+1;
f[u][0]=fa;
s[u].insert(s[fa],a[u],dep[u]);
st[++sta]=u;int mx=dep[u],now=sta;
for(re int i=front[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa)continue;
mx=max(mx,dfs(v,u));
}
if(mx-dep[u]>=B || now==1)
{
rt[++num]=u;
for(re int i=now;i<=sta;i++)bl[st[i]]=num;
sta=now-1;return dep[u]-1;
}
return mx;
}
int lca(int u,int v)
{
if(dep[u]<dep[v])swap(u,v);
for(re int i=20;~i;i--)
if(dep[u]-(1<<i)>=dep[v])u=f[u][i];
if(u==v)return u;
for(re int i=20;~i;i--)
if(f[u][i]!=f[v][i])
u=f[u][i],v=f[v][i];
return f[u][0];
}
inline void getans(int u,int fa,int BL)
{
if(++F[a[u]]==1)kind++;
ans[BL][u]=kind;
for(re int i=front[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa)continue;
getans(v,u,BL);
}
if(--F[a[u]]==0)kind--;
}
int solve_same(int x,int y)
{
sta=0;
for(kind=0;x!=y;x=f[x][0])
{
if(dep[x]<dep[y])swap(x,y);
if(!F[a[x]]++)++kind,st[++sta]=a[x];
}
int QAQ=kind+(!F[a[x]]);
for(;sta;sta--)F[st[sta]]=0;
return QAQ;
}
int solve_diff(int x,int y)
{
if(dep[rt[bl[x]]]<dep[rt[bl[y]]])swap(x,y);
int sum=ans[bl[x]][y];
int z=rt[bl[x]],d=dep[lca(x,y)];
sta=0;
for(;x!=z;x=f[x][0])
{
if(!F[a[x]] && s[z][a[x]]<d && s[y][a[x]]<d)
F[st[++sta]=a[x]]=1,sum++;
}
for(;sta;sta--)F[st[sta]]=0;
return sum;
}
int n,m;
void print(int x)
{
if(x>=10)print(x/10);
putchar(x%10+'0');
}
int main()
{
n=gi();m=gi();B=sqrt(n);
for(int i=1;i<=n;i++)Bl[i]=(i-1)/B+1,P[i]=i%B;
for(re int i=1;i<=n;i++)a[i]=b[i]=gi();
sort(b+1,b+n+1);int N=unique(b+1,b+n+1)-b-1;
for(re int i=1;i<=n;i++)
a[i]=lower_bound(b+1,b+N+1,a[i])-b;
for(re int i=1;i<n;i++)
{
int u=gi(),v=gi();
Add(u,v);Add(v,u);
}
dfs(1,1);
for(re int i=1;i<=num;i++)getans(rt[i],rt[i],i);
for(re int j=1;j<=20;j++)
for(re int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1];
int lastans=0;
while(m--)
{
int u=gi(),v=gi();
if(bl[u]==bl[v])lastans=solve_same(u,v);
else lastans=solve_diff(u,v);
print(lastans);putchar('
');
}
return 0;
}