题意:一棵树q次查询,每次查询给三个不同的点,要求计算到这三个点的比其他两个距离都要小的点数
题解:很明显的lca,倍增的找中点,关键是两个点的中点很好找,但是三个点不好找,我刚开始还准备分类讨论,后来发现巨麻烦,其实可以用线段树来维护算a的答案其实就是a在b下的答案和a在c下的答案的交集,可以用线段树区间求和区间查询做,每次更新完之后复原就不用memset线段树了
//#pragma comment(linker, "/stack:200000000") //#pragma GCC optimize("Ofast,no-stack-protector") //#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native") //#pragma GCC optimize("unroll-loops") #include<bits/stdc++.h> #define fi first #define se second #define mp make_pair #define pb push_back #define pi acos(-1.0) #define ll long long #define vi vector<int> #define mod 1000000007 #define C 0.5772156649 #define ls l,m,rt<<1 #define rs m+1,r,rt<<1|1 #define pil pair<int,ll> #define pli pair<ll,int> #define pii pair<int,int> #define cd complex<double> #define ull unsigned long long #define base 1000000000000000000 #define fio ios::sync_with_stdio(false);cin.tie(0) using namespace std; const double g=10.0,eps=1e-12; const int N=100000+10,maxn=5000000+10,inf=0x3f3f3f3f,INF=0x3f3f3f3f3f3f3f3f; vi v[N]; int dep[N],n,sz[N],fa[20][N]; int le[N],ri[N],id[N],cnt; void dfs(int u,int f) { le[u]=++cnt; id[cnt]=u; fa[0][u]=f; sz[u]=1; for(int i=0;i<v[u].size();i++) { int x=v[u][i]; if(x!=f)dep[x]=dep[u]+1,dfs(x,u),sz[u]+=sz[x]; } ri[u]=cnt; } int lazy[N<<2],val[N<<2]; void pushdown(int l,int r,int rt) { if(lazy[rt]!=0) { int m=(l+r)>>1; val[rt<<1]+=(m-l+1)*lazy[rt]; val[rt<<1|1]+=(r-m)*lazy[rt]; lazy[rt<<1]+=lazy[rt]; lazy[rt<<1|1]+=lazy[rt]; lazy[rt]=0; } } void pushup(int rt) { val[rt]=val[rt<<1]+val[rt<<1|1]; } void build(int l,int r,int rt) { lazy[rt]=val[rt]=0; if(l==r)return ; int m=(l+r)>>1; build(ls);build(rs); } void update(int c,int L,int R,int l,int r,int rt) { if(L<=l&&r<=R) { val[rt]+=(r-l+1)*c; lazy[rt]+=c; return ; } pushdown(l,r,rt); int m=(l+r)>>1; if(L<=m)update(c,L,R,ls); if(m<R)update(c,L,R,rs); pushup(rt); } int query(int L,int R,int l,int r,int rt) { if(L<=l&&r<=R)return val[rt]; pushdown(l,r,rt); int m=(l+r)>>1,ans=0; if(L<=m)ans+=query(L,R,ls); if(m<R)ans+=query(L,R,rs); return ans; } void init() { dep[1]=1; cnt=0; dfs(1,-1); build(1,cnt,1); for(int i=1;i<20;i++) for(int j=1;j<=n;j++) fa[i][j]=fa[i-1][fa[i-1][j]]; } int lca(int x,int y) { if(dep[x]>dep[y])swap(x,y); for(int i=0;i<20;i++) if((dep[y]-dep[x])>>i&1) y=fa[i][y]; if(x==y)return x; for(int i=19;i>=0;i--) { if(fa[i][x]!=fa[i][y]) { x=fa[i][x]; y=fa[i][y]; } } return fa[0][x]; } int go(int u,int dis) { for(int i=19;i>=0;i--) if(dis>=(1<<i)) dis-=(1<<i),u=fa[i][u]; return u; } int solve(int a,int b,int c) { int tle,tri,ans=0; if(dep[a]>=dep[b]) { int dis=dep[a]+dep[b]-2*dep[lca(a,b)]; int x=go(a,dis/2); if(dis%2==0)x=go(a,dis/2-1); update(1,le[x],ri[x],1,cnt,1); tle=le[x],tri=ri[x]; } else { int dis=dep[a]+dep[b]-2*dep[lca(a,b)]; int x=go(b,dis/2); update(1,1,cnt,1,cnt,1); update(-1,le[x],ri[x],1,cnt,1); tle=le[x],tri=ri[x]; } if(dep[a]>=dep[c]) { int dis=dep[a]+dep[c]-2*dep[lca(a,c)]; int x=go(a,dis/2); if(dis%2==0)x=go(a,dis/2-1); ans=query(le[x],ri[x],1,cnt,1); } else { int dis=dep[a]+dep[c]-2*dep[lca(a,c)]; int x=go(c,dis/2); ans=query(1,cnt,1,cnt,1); ans-=query(le[x],ri[x],1,cnt,1); } if(dep[a]>=dep[b])update(-1,tle,tri,1,cnt,1); else update(-1,1,cnt,1,cnt,1),update(1,tle,tri,1,cnt,1); return ans; } int main() { int T;scanf("%d",&T); while(T--) { scanf("%d",&n); for(int i=1;i<=n;i++)v[i].clear(); for(int i=1;i<n;i++) { int a,b; scanf("%d%d",&a,&b); v[a].pb(b),v[b].pb(a); } init(); int q;scanf("%d",&q); while(q--) { int a,b,c; scanf("%d%d%d",&a,&b,&c); // printf("%d ",solve(b,a,c)); printf("%d %d %d ",solve(a,b,c),solve(b,a,c),solve(c,a,b)); } } return 0; } /*********************** 1 9 1 2 1 3 1 4 2 5 2 6 2 7 6 8 6 9 2 1 2 8 2 1 4 ***********************/