原文链接https://www.cnblogs.com/zhouzhendong/p/NowCoder-2018-Summer-Round7-I.html
题目传送门 - https://www.nowcoder.com/acm/contest/145/I
题意
给定一棵有 $n$ 个节点的树,问有多少个点集的直径恰好等于 $D$ 。
一个点集的直径定义为该点集中距离最远的两个点的距离。
两个点的距离定义为他们在树上的最短路径经过的边数。
$nleq 10^5$
题解
我的做法有点难写,官方做法我没写过。
但是我的做法是 (截至 2018-08-10 10:40) 跑的最快的。
我们首先考虑一个简单的差分。记 $solve(D)$ 为直径不大于 $D$ 的答案。最终的答案显然为 $solve(D)-solve(D+1)$ 。
我们然后考虑先用 DP 来求一个 $solve(D)$ 。先给树定一个根,比如说 $1$ 号节点。令 $dp[i][j]$ 表示在子树 $i$ 中选择点集,使得这个点集的直径不超过 $D$,并使得与点 $i$ 最远的点与点 $i$ 的距离为 $j$ 的方案数。我们显然可以 树形dp 来解决这个问题。这个树形dp 的主要操作是合并两个子树。由于所有距离 $i$ 的点在 DP 上去的过程中都是等价的,我们可以把一串 DP 信息看做一条链。合并子树就是合并链。考虑两条链的合并,对于两条链 $x,y$ ,各选择一个最深深度,假设分别为 $i,j$ 。记合并 $x,y$ 之后的结果是 $z$ 数组,则对于每一对 $i,j$ ,我们需要更新的是: $dp[z][max(i,j)]+=dp[x][i] imes dp[y][j]$ 。
这样做,时间复杂度是 $O(n^2)$ 的。
我们重新考虑合并两条链的过程。我们枚举短的那一条链的每一个节点,然后分两种情况把贡献加到另一条链上,用线段树快速维护,线段树支持的操作为区间加和区间乘。由于每一次枚举的是短链上的情况,所以总的枚举次数不会大于整个树的大小。然后您大概还没有看懂我在说什么。嗯,每次线段树维护时间复杂度 $O(log n)$ ,所以总的时间复杂度为 $O(nlog n)$ 。
为了方便,我们先长链剖分一下,每次先 DFS 最长链。我们来分情况讨论一下我们要支持哪些操作:
1. 链顶加上一个节点:考虑当前节点所在的长链的下一个节点所代表的子树信息都已经合并到那个节点所代表的链上了,我们要把它接到当前节点上面来。首先,对于深度 $leq $ D 的节点,我们可以考虑再加上当前节点,所以我们要将他们区间 $ imes 2$ ;考虑不取那个子树的点,把只取当前节点的方案作为一个新集合,显然它与当前节点的距离为 $0$ ,所以我们可以单点加。
2. 合并当前节点(后面也把当前节点叫做根)的轻儿子的链。我们可以把所有轻儿子的链顶端想象着往上接一个不能选择的父亲点,这样,可以使得这条个轻儿子开头的链的起始深度与当前子树已经合并完成的链相同。然后我们考虑枚举轻儿子链上的每一个节点,分两种情况用这个节点更新重链。
假设当前节点距离根 $j$ 。下面考虑重链上与根的距离不大于 $D-j$ 的节点(这样可以保证任意两个点之间的距离不大于 D),把这些点构成的点集记为 S。
第一种情况:考虑选择S中距离根不大于 $j$ 的节点,假设 它距离根 $i$ ,由于 $max(i,j)=j$,那么显然更新得到的结果会加在重链中距离根 $j$ 的位置的结果上面;具体的操作是:询问前缀和,单点修改。
第二种情况:考虑选择S中距离根大于 $j$ 的节点,那么显然会更新到它自己所在的位置上面。假设当前轻链上第 $j$ 个位置的方案为 $v$ ,则对于所有这次选择的节点,就相当于区间乘 $v+1$ 。
于是我开心的把上面的做法实现了一下,然后 WA 了。错在了哪里?
回忆一下,之前的重链结果是要备份的。区间乘是假的,本质是加上原数组的 $v$ 倍。这是个棘手的问题。对于第一种情况,我们很容易处理,只需要把询问放到第二种情况的修改之前,操作放到之后做就可以了。
对于第二种情况,我们发现涉及的区间只有可能存在包含关系。于是我们考虑一下,如何快速修改。
考虑当前区间乘的倍数大概是这样:
1 1 a a a+b a+b a+b a+b a+b a 1 1 1
然后我们要在中间的 $a+b$ 的某一段区间再加上原数组的 $c$ 倍,我的做法是:先将对应区间除以 $a+b$ ,然后再乘上 $a+b+c$ 。
于是我们的问题就解决了。
但是这个操作其实在十分特殊的情况下是会挂掉的:因为乘 0 操作不能用 乘以 $0$ 的逆元 来取消,所以一旦 和为 $0$(在对于 $10^9+7$ 取膜的意义下),我就很可能挂掉了(但是我试了一下,数据里没有出现这种情况)。
如果出现这种情况,也有解决的办法:
考虑一下左边和右边出现相邻值不同的位置,这些位置只可能出现在左右各 [轻链size] 的范围,所以我们可以把所有的操作记下来,中间的很长一段直接区间乘,周围的单点加即可。但是这样写很麻烦,而且博主十分懒,肯定不会去写的呀。
至此,我大致的说明了我的做法。标算的点分治是什么我并不知道……
没看懂的就……请留言指出没看懂的地方,我有空的时候回复您吧……
附赠一组数据:
7
1 2
1 3
3 4
4 5
2 6
2 7
5
ans = 48
代码
#include <bits/stdc++.h> using namespace std; const int N=100005,mod=1e9+7; int read(){ int x=0; char ch=getchar(); while (!isdigit(ch)) ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+ch-48,ch=getchar(); return x; } int Pow(int x,int y){ int ans=1; for (;y;y>>=1,x=1LL*x*x%mod) if (y&1) ans=1LL*ans*x%mod; return ans; } struct Gragh{ static const int M=N*2; int cnt,y[M],nxt[M],fst[N]; void clear(){ cnt=0; memset(fst,0,sizeof fst); } void add(int a,int b){ y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt; } }g; int n; int fa[N],sz[N],top[N],Maxd[N],depth[N],p[N],ap[N]; vector <int> son[N]; bool cmp(int x,int y){ return Maxd[x]>Maxd[y]; } void dfs(int x,int pre,int d){ depth[x]=Maxd[x]=d,fa[x]=pre; son[x].clear(); for (int i=g.fst[x];i;i=g.nxt[i]) if (g.y[i]!=pre){ int y=g.y[i]; dfs(y,x,d+1); Maxd[x]=max(Maxd[x],Maxd[y]); son[x].push_back(y); } sort(son[x].begin(),son[x].end(),cmp); sz[x]=(int)son[x].size(); } int Time=0; void Get_Top(int x,int TOP){ top[x]=TOP; ap[p[x]=++Time]=x; if (!sz[x]) return; Get_Top(son[x][0],TOP); for (int i=1;i<sz[x];i++) Get_Top(son[x][i],son[x][i]); } struct Seg{ int v,add; }t[N<<2]; void build(int rt,int L,int R){ t[rt].v=0,t[rt].add=1; if (L==R) return; int mid=(L+R)>>1,ls=rt<<1,rs=ls|1; build(ls,L,mid); build(rs,mid+1,R); } void Times(int rt,int d){ t[rt].v=1LL*t[rt].v*d%mod; t[rt].add=1LL*t[rt].add*d%mod; } void pushdown(int rt){ int ls=rt<<1,rs=ls|1,&v=t[rt].add; if (v==1) return; Times(ls,v); Times(rs,v); v=1; } void update(int rt,int L,int R,int xL,int xR,int opt,int d){ if (L>xR||R<xL||xL>xR) return; if (xL<=L&&R<=xR){ if (opt==0) Times(rt,d); else t[rt].v=(t[rt].v+d)%mod; return; } pushdown(rt); int mid=(L+R)>>1,ls=rt<<1,rs=ls|1; update(ls,L,mid,xL,xR,opt,d); update(rs,mid+1,R,xL,xR,opt,d); t[rt].v=(t[ls].v+t[rs].v)%mod; } int query(int rt,int L,int R,int xL,int xR){ if (L>xR||R<xL||xL>xR) return 0; if (xL<=L&&R<=xR) return t[rt].v; pushdown(rt); int mid=(L+R)>>1,ls=rt<<1,rs=ls|1; return (query(ls,L,mid,xL,xR)+query(rs,mid+1,R,xL,xR))%mod; } void Prepare(){ dfs(1,0,0); Get_Top(1,1); } int D,addv[N],sv[N]; void DFS(int x){ update(1,1,n,p[x],p[x],1,1); if (sz[x]){ DFS(son[x][0]); update(1,1,n,p[x]+1,p[x]+min(D,Maxd[x]-depth[x]),0,2); for (int i=1;i<sz[x];i++){ int y=son[x][i],lim=Maxd[x]-depth[x]; int vy=min(D,Maxd[y]-depth[y]+1); DFS(y); int lastv=1; for (int j=0;j<=vy;j++) addv[j]=0; for (int j=1;j<=vy;j++) sv[j]=query(1,1,n,p[x],p[x]+min(j,D-j)); for (int j=1;j<=vy;j++){ int v=query(1,1,n,p[y]+j-1,p[y]+j-1); addv[j]=(1LL*(sv[j]+1)*v+addv[j])%mod; int k=D-j; if (k>j){ int inv=Pow(lastv,mod-2); update(1,1,n,p[x]+j+1,p[x]+min(k,lim),0,inv); lastv=(lastv+v)%mod; update(1,1,n,p[x]+j+1,p[x]+min(k,lim),0,lastv); } } for (int j=0;j<=vy;j++) update(1,1,n,p[x]+j,p[x]+j,1,addv[j]); for (int j=vy+1;j<=Maxd[y]-depth[y]+1;j++){ int v=query(1,1,n,p[y]+j-1,p[y]+j-1); update(1,1,n,p[x]+j,p[x]+j,1,v); } } } } int solve(int DD){ if (DD==0) return n; build(1,1,n); D=DD; DFS(1); return query(1,1,n,1,Maxd[1]+1); } int main(){ n=read(); g.clear(); for (int i=1;i<n;i++){ int a=read(),b=read(); g.add(a,b); g.add(b,a); } Prepare(); int DD=read(); printf("%d",(solve(DD)-solve(DD-1)+mod)%mod); return 0; }