Qtree系列第三题
读完题大概不难判断是一道树剖的题
这道题的关键是记录两种状态,以及黑点的序号(不是编号)
线段树啊当然
定义两个变量v,f,v表示距离根节点最近的黑点,默认-1,f则表示区间内是否含有黑点,有为1,无为0
那么,怎么才能取当前路径距离根节点最近的黑点的呢?线段树更新时,优先取左子区间的黑点,后取右子区间,更新答案时,优先取后出现的答案
那么直接看程序吧
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cctype>
#define ll long long
#define gc() getchar()
#define maxn 100005
using namespace std;
inline ll read(){
ll a=0;int f=0;char p=gc();
while(!isdigit(p)){f|=p=='-';p=gc();}
while(isdigit(p)){a=(a<<3)+(a<<1)+(p^48);p=gc();}
return f?-a:a;
}
void write(ll a){
if(a>9)write(a/10);
putchar(a%10+'0');
}
int n,m;
struct ahaha1{ //前向星存边
int to,next;
}e[maxn<<1];int head[maxn],tot;
inline void add(int u,int v){
e[tot].to=v;e[tot].next=head[u];head[u]=tot++;
}
int sz[maxn],dep[maxn],f[maxn],son[maxn];
void dfs(int u,int fa){
sz[u]=1;int maxa=0;
for(int i=head[u];~i;i=e[i].next){
int v=e[i].to;if(v==fa)continue;
f[v]=u;dep[v]=dep[u]+1;
dfs(v,u);sz[u]+=sz[v];
if(maxa<sz[v]){maxa=sz[v];son[u]=v;}
}
}
int top[maxn],in[maxn],b[maxn];
void dfs(int u,int fa,int topf){
in[u]=++tot;top[u]=topf;b[tot]=u;
if(!son[u])return;
dfs(son[u],u,topf);
for(int i=head[u];~i;i=e[i].next){
int v=e[i].to;if(v==fa||v==son[u])continue;
dfs(v,u,v);
}
}
#define lc p<<1
#define rc p<<1|1
struct ahaha2{
int v;bool f;
ahaha2(){v=-1;} //v的初始值为-1
}t[maxn<<2];
inline void pushup(int p){
t[p].f=t[lc].f|t[rc].f; //若左右子区间中含有黑点,则当前区间含有黑点
t[p].v=t[lc].f?t[lc].v:(t[rc].f?t[rc].v:-1); //优先取左子区间的黑点,使距离根节点尽可能近
}
void update(int p,int l,int r,int L){
if(l==r){t[p].f^=1;t[p].v=t[p].f?b[l]:-1;return;} //若改变后为黑点,则赋值为节点序号,否则重置
int m=l+r>>1;
if(L<=m)update(lc,l,m,L);
else update(rc,m+1,r,L);
pushup(p);
}
int query(int p,int l,int r,int L,int R){
if(l>R||r<L)return -1;
if(L<=l&&r<=R)return t[p].v;
int m=l+r>>1,l1=query(lc,l,m,L,R),r1=query(rc,m+1,r,L,R); //优先取左子区间
return l1==-1?r1:l1;
}
inline void solve_1(){
int x=read();
update(1,1,n,in[x]);
}
inline void solve_2(){
int x=read(),ans=-1,p; //ans优先取后出现的的答案
while(top[x]!=1){
p=query(1,1,n,in[top[x]],in[x]);
ans=(p==-1?ans:p);
x=f[top[x]];
}
p=query(1,1,n,1,in[x]);
ans=(p==-1?ans:p);
if(ans<0)putchar('-'),ans=-ans;
write(ans);putchar('
');
}
int main(){memset(head,-1,sizeof head);
n=read();m=read();
for(int i=1;i<n;++i){
int x=read(),y=read();
add(x,y);add(y,x);
}
tot=0;dfs(1,-1);dfs(1,-1,1);
for(int i=1;i<=m;++i){
int z=read();
switch(z){
case 0:solve_1();break;
case 1:solve_2();break;
}
}
return 0;
}