51Nod1709 复杂度分析
思路
一道很有意思的题。
我们考虑按位累计贡献。记录(dp[i][j])为上一步倍增跳了(2^j)步,跳到了(i)点的所有状态"1"的个数和。
可以理解为,对于一个儿子u
和他的祖先v
的深度差,可以表示为一个二进制数。我们倍增的将u
跳到v
,每次跳的长度一定比上次小,且一定是(2^k)。
对于一个点i
,上一步跳了(2^j)步,那么就(forall k,k<j),dp[fa][k]+=dp[i][j]+num[i][j]
。其中fa
代表的是i
第(2^k)个祖先,num
代表的是某个状态有多少种方案。那么这句话的意思就是,对于所有可以由((i,j))这个状态转移到的状态,都把他们的dp
加上((原状态dp值+原状态方案数那么多个“1”))。同时num[fa][k]+=num[i][k]
,ans+=dp[fa][k]*(size[fa]-size[from]-1)
,from
指的是这一次是从fa
的哪个亲儿子跳上来的。
那么连初始化都不需要了,直接多考虑一种情况跳跃即可。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 105000
#define ll long long
using namespace std;
int dp[maxn][23],f[maxn][23],num[maxn][23],size[maxn],dep[maxn],mem[maxn][23];
ll ans;
struct gg{
int u,v,next;
}side[maxn*2];
int head[maxn],cnt,n;
void insert(int u,int v){
struct gg add={u,v,head[u]};side[++cnt]=add;head[u]=cnt;
}
void dfs1(int now,int fa){
f[now][0]=fa;size[now]=1;dep[now]=dep[fa]+1;
for(int i=1;i<=22;i++)f[now][i]=f[f[now][i-1]][i-1];
for(int i=head[now];i;i=side[i].next){
int v=side[i].v;if(v==fa)continue;
dfs1(v,now);size[now]+=size[v];
}
}
int jump(int tar,int now){
for(int i=22;i>=0;i--)if(dep[f[now][i]]>dep[tar])now=f[now][i];
return now;
}
void init(){
for(int i=1;i<=n;i++){
for(int j=0;j<=22;j++){
int tar=f[i][j];if(!tar)continue;
mem[i][j]=jump(tar,i);
}
}
}
void dfs2(int now,int fa){
for(int i=head[now];i;i=side[i].next){
int v=side[i].v;if(v==fa)continue;
dfs2(v,now);
}
for(int i=1;i<=22;i++){//上次跳了2^i步
for(int j=0;j<i;j++){//这次跳了2^j步
int tar=f[now][j];//tar为这次跳跃的目的地
if(!tar)continue;
dp[tar][j]+=dp[now][i]+num[now][i];
num[tar][j]+=num[now][i];
ans+=(dp[now][i]+num[now][i])*(size[tar]-size[mem[now][j]]);
}
}
for(int j=0;j<=22;j++){//这次跳了2^j步
int tar=f[now][j];//tar为这次跳跃的目的地
if(!tar)continue;
dp[tar][j]+=1;
num[tar][j]+=1;
ans+=(size[tar]-size[mem[now][j]]);
}
}
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int _read(){
char ch=nc();int sum=0;
while(!(ch>='0'&&ch<='9'))ch=nc();
while(ch>='0'&&ch<='9')sum=sum*10+ch-48,ch=nc();
return sum;
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int u,v;u=_read();v=_read();insert(u,v);insert(v,u);
}
dfs1(1,0);
init();
dfs2(1,0);
cout<<ans;
}