题目
题目链接:https://atcoder.jp/contests/agc002/tasks/agc002_d
一张连通图,(q) 次询问从两个点 (x) 和 (y) 出发,希望经过的点(不重复)数量等于 (z),经过的边最大编号最小是多少。
注意并不是要求找一条连接 (x) 和 (y) 的路径。
(n,m,Qleq 10^5)。
思路
这道题挺多做法的。线段树或者 kruskal 重构树都可以过。所以考虑整体二分。
对于每一个询问有一个很直接的思路就是二分最大边的编号,然后 (O(n+m)) check 一下是否可行。总复杂度是 (O(Q(n+m)alpha(n))) 的。
由于这道题不强制在线,所以我们考虑整体二分。具体的,假设我们现在确定了答案在边权 ([l,r]) 的询问的编号在 ([ql,qr]),那么我们取 (mid=frac{l+r}{2}),将 ([1,mid]) 的边所连接的点合并,然后枚举询问判断一下连通快大小是否不小于 (z) 即可。
但是这样的话,每次合并的复杂度是 (O(m)) 的,区间的数量为 (O(Qlog Q)),总复杂度就是 (O(Qmlog Qalpha(n))),不可接受。
发现在分治的时候,区间 ([l,r]) 会分治到 ([l,mid]) 和 ((mid,r]),而此时 ([1,mid]) 的边是连好的,如果我们往 ([l,mid]) 走,就要删除 ((frac{l+mid}{2},mid]) 的边,否则插入 ((mid,frac{mid+r}{2}]) 的边。考虑每条边的贡献,显然每条边最多被删除或插入 (O(log Q)) 次。如果可以支持快速删除插入,那么复杂度就正确了。
所以我们不能路径压缩,采用按秩合并。这样树高保持在 (O(log n)) 级别。然后用一个栈记录下被合并的所有边以及合并顺序,如果要删除一部分边,那么就不断弹出栈顶并还原;否则继续插入。
巧妙的是我们插入边的时候是按照边的编号来依次插入的,所以这样保证了正确性。
时间复杂度 (O(mlog nlog Q))。
代码
#include <bits/stdc++.h>
#define mp(x,y,z) make_pair(x,make_pair(y,z))
#define ST first
#define ND second.first
#define RD second.second
using namespace std;
const int N=100010;
int n,m,Q,father[N],siz[N],ans[N];
stack<pair<int,pair<int,int> > > st;
struct edge
{
int u,v;
}e[N];
struct Query
{
int x,y,z,id;
}ask[N],b[N];
int find(int x)
{
return father[x]==x?x:find(father[x]);
}
void merge(int x,int y,int id)
{
x=find(x); y=find(y);
if (x==y) return;
if (siz[x]<siz[y])
{
siz[y]+=siz[x]; father[x]=y;
st.push(mp(id,x,y));
}
else
{
siz[x]+=siz[y]; father[y]=x;
st.push(mp(id,y,x));
}
}
void binary(int l,int r,int ql,int qr)
{
int mid=(l+r)>>1,p=ql-1,q=qr+1;
while (st.size() && st.top().ST>mid)
{
pair<int,pair<int,int> > t=st.top();
siz[t.RD]-=siz[t.ND];
father[t.ND]=t.ND;
st.pop();
}
for (int i=l;i<=mid;i++)
merge(e[i].u,e[i].v,i);
for (int i=ql;i<=qr;i++)
{
int x=find(ask[i].x),y=find(ask[i].y);
if (x==y)
{
if (siz[x]>=ask[i].z)
b[++p]=ask[i],ans[ask[i].id]=mid;
else
b[--q]=ask[i];
}
else
{
if (siz[x]+siz[y]>=ask[i].z)
b[++p]=ask[i],ans[ask[i].id]=mid;
else
b[--q]=ask[i];
}
}
for (int i=ql;i<=qr;i++) ask[i]=b[i];
if (l==r) return;
binary(l,mid,ql,p); binary(mid+1,r,q,qr);
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=m;i++)
scanf("%d%d",&e[i].u,&e[i].v);
scanf("%d",&Q);
for (int i=1;i<=Q;i++)
{
scanf("%d%d%d",&ask[i].x,&ask[i].y,&ask[i].z);
ask[i].id=i; ans[i]=m;
}
for (int i=1;i<=n;i++)
father[i]=i,siz[i]=1;
binary(1,m,1,Q);
for (int i=1;i<=Q;i++)
printf("%d
",ans[i]);
return 0;
}