测试地址:寝室管理
题目大意:给定一棵树或环套树,求图中经过至少个点的路径数。
做法:本题需要用到环套树+点分治+树状数组。
先考虑树上的做法。对于这种树上路径计数的问题,应该能形成一种条件反射了,不能DP马上想到点分治。点分治中,每一次我们考虑过某个点的合法路径数时,先把子树列成一列,对于一棵子树里的所有点,它到根的距离和之前子树中的点到根的距离应该满足才是合法的,那么实际上我们就是要求之前子树中满足的点的数量,这样一个明显的后缀和形式显然可以用树状数组维护。那么我们就得到了一个的树上的算法。
那么再考虑环套树。首先对于所有外向树,我们都可以点分治出该外向树中的所有合法路径,因此我们只需要再考虑过环上的路径即可。为了不算重,我们需要计算从每个环上点的外向树中的点,顺时针(或逆时针,总之就是按同一个方向)走环,最后走到某个其他点的合法路径数。按套路破环为链并倍长,然后顺次编号,那么如果两个点到它们对应的外向树的根的距离分别为,而它们外向树的编号分别为,不妨设(时就是同一棵外向树了,我们已经算过了),那么之间的路径合法当且仅当成立。也就是成立。因此我们可以把看做每个点的权值,这样我们就可以相似地用树状数组求出满足条件的点数了。上述算法的时间复杂度为,加上点分治,总的时间复杂度还是,可以通过此题。
我傻逼的地方:太久没写大代码了,重心又求错了,TLE了两发……我可能要NOIP退役了……
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,k,first[100010]={0},tot=0,limit=0;
int st[100010],top,siz[100010],mxson[100010];
int inst[100010]={0},loop[100010],looplen;
ll ans=0,sum[400010]={0};
bool vis[100010]={0},inloop[100010]={0};
struct edge
{
int v,next;
}e[200010];
void insert(int a,int b)
{
e[++tot].v=b;
e[tot].next=first[a];
first[a]=tot;
}
int lowbit(int x)
{
return x&(-x);
}
void add(int x,ll d)
{
for(int i=x;i<=(n<<2);i+=lowbit(i))
sum[i]+=d;
}
ll query(int x)
{
ll ans=0;
for(int i=x;i;i-=lowbit(i))
ans+=sum[i];
return ans;
}
ll Sum(int l,int r)
{
if (r<1||l>r) return 0;
return query(r)-query(l-1);
}
void dp(int v,int fa)
{
st[++top]=v;
siz[v]=1,mxson[v]=0;
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v]&&e[i].v!=fa)
{
if (inloop[v]&&inloop[e[i].v]) continue;
dp(e[i].v,v);
mxson[v]=max(mxson[v],siz[e[i].v]);
siz[v]+=siz[e[i].v];
}
}
int find_ctr(int v)
{
top=0;
dp(v,0);
int mn=1000000000,mni;
for(int i=1;i<=top;i++)
if (max(mxson[st[i]],siz[v]-siz[st[i]])<mn)
{
mn=max(mxson[st[i]],siz[v]-siz[st[i]]);
mni=st[i];
}
return mni;
}
void maintain(int v,int fa,int dis,ll d)
{
add(dis,d);
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v]&&e[i].v!=fa)
{
if (inloop[v]&&inloop[e[i].v]) continue;
maintain(e[i].v,v,dis+1,d);
}
}
void calc(int v,int fa,int dis)
{
ans+=Sum(max(1,k-dis-1+limit),n<<2);
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v]&&e[i].v!=fa)
{
if (inloop[v]&&inloop[e[i].v]) continue;
calc(e[i].v,v,dis+1);
}
}
void solve(int v)
{
v=find_ctr(v);
vis[v]=1;
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v])
{
if (inloop[v]&&inloop[e[i].v]) continue;
solve(e[i].v);
}
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v])
{
if (inloop[v]&&inloop[e[i].v]) continue;
calc(e[i].v,0,1);
maintain(e[i].v,0,1,1);
}
ans+=Sum(k-1,n);
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v])
{
if (inloop[v]&&inloop[e[i].v]) continue;
maintain(e[i].v,0,1,-1);
}
vis[v]=0;
}
bool find_loop(int v,int fa)
{
st[++top]=v;
inst[v]=top;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa)
{
if (!inst[e[i].v])
{
if (find_loop(e[i].v,v))
return 1;
}
else
{
looplen=0;
for(int j=inst[e[i].v];j<=top;j++)
{
loop[++looplen]=st[j];
inloop[st[j]]=1;
}
return 1;
}
}
top--;
inst[v]=0;
return 0;
}
int main()
{
scanf("%d%d%d",&n,&m,&k);
for(int i=1;i<=m;i++)
{
int a,b;
scanf("%d%d",&a,&b);
insert(a,b),insert(b,a);
}
if (m<n)
{
solve(1);
}
else
{
top=0;
find_loop(1,0);
for(int i=1;i<=looplen;i++)
solve(loop[i]);
for(int i=1;i<=looplen;i++)
maintain(loop[i],0,i,1);
for(limit=1;limit<=looplen;limit++)
{
maintain(loop[limit],0,limit,-1);
calc(loop[limit],0,0);
maintain(loop[limit],0,looplen+limit,1);
}
}
printf("%lld",ans);
return 0;
}