测试点1,2:
直接暴力,(n^2)预处理然后(O(1))查询。
代码:
#include <cstdio>
const int maxn=5e3+10;
struct E{
int to;
int next;
}ed[maxn*2];
int head[maxn];
int tot;
void J(int a,int b){
ed[++tot].to=b;
ed[tot].next=head[a];
head[a]=tot;
}
int a[maxn];
long long jl[maxn][maxn];
void get(int x,int fa,long long ha,int dis,int id){
jl[id][x]=ha+(dis|a[x]);
for(int i=head[x];i;i=ed[i].next){
if(ed[i].to==fa)
continue;
get(ed[i].to,x,ha+(dis|a[x]),dis+1,id);
}
}
int main(){
int n,q;
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
int u,v;
for(int i=1;i<=n-1;i++){
scanf("%d%d",&u,&v);
J(u,v);
J(v,u);
}
for(int i=1;i<=n;i++)
get(i,0,0,0,i);
for(int i=1;i<=q;i++){
scanf("%d%d",&u,&v);
printf("%lld
",jl[u][v]);
}
return 0;
}
测试点4:
只需要知道(dist(u,w))为奇数并且(a_w=1)的(w)有多少个,差分(dp_x=dp_{fa_{fa_x}}+a_{fa_x})
然后统计答案的时候根据lca与u,u与v之间距离的奇偶判断该如何计算。
也许更简单的实现,建两颗树,深度为奇数的和深度为偶数的树(为了方便给1一个父亲),然后如果u,v不属于同一树,(v=fa_v),然后直接算u到v的1的个数的和(好像还要判断一下lca算不算)。然后加上(dist(u,v)*(dist(u,v)+1)/2)(u指原来的u)
测试点8:
随机?直接暴力跳然后统计答案。
正解
其实直接考虑贡献即可,然后倍增,不过处理起来还是比较妙的。
#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn=3e5+10;
struct E{
int to;
int next;
}ed[maxn*2];
int head[maxn];
int tot;
void J(int a,int b){
ed[++tot].to=b;
ed[tot].next=head[a];
head[a]=tot;
}
int t[maxn][21];
int cnt[maxn][21];
long long sum1[maxn][21];
long long sum2[maxn][21];
int deep[maxn];
int a[maxn];
void Dfs(int x,int fa){
deep[x]=deep[fa]+1;
for(int i=0;i<=20;i++)
cnt[x][i]=cnt[fa][i]+(!((1<<i)&a[x]));
t[x][0]=fa;
sum1[x][0]=sum2[x][0]=a[x];
for(int i=1;i<=20;i++){
t[x][i]=t[t[x][i-1]][i-1];
sum1[x][i]=sum1[x][i-1]+sum1[t[x][i-1]][i-1]+1ll*(1<<(i-1))*(cnt[t[x][i-1]][i-1]-cnt[t[x][i]][i-1]);
sum2[x][i]=sum2[x][i-1]+sum2[t[x][i-1]][i-1]+1ll*(1<<(i-1))*(cnt[x][i-1]-cnt[t[x][i-1]][i-1]);
}
for(int i=head[x];i;i=ed[i].next){
if(ed[i].to==fa)
continue;
Dfs(ed[i].to,x);
}
}
int lca(int x,int y){
if(deep[x]<deep[y])
swap(x,y);
int k=deep[x]-deep[y];
int id=0;
while(k){
if(k&1)
x=t[x][id];
id++;
k>>=1;
}
if(x==y)
return x;
for(int i=20;i>=0;i--)
if(t[x][i]!=t[y][i]){
x=t[x][i];
y=t[y][i];
}
return t[x][0];
}
long long query1(int x,int y){
long long ans=0;
int len=deep[x]-deep[y];
for(int i=20;i>=0;i--)
if(len&(1<<i)){
ans+=sum1[x][i];
x=t[x][i];
ans+=1ll*(1<<i)*(cnt[x][i]-cnt[y][i]);
}
return ans;
}
long long query2(int x,int len){
long long ans=0;
int now=x;
for(int i=0;i<=20;i++)
if(len&(1<<i)){
ans+=sum2[now][i];
ans+=1ll*(1<<i)*(cnt[x][i]-cnt[now][i]);
now=t[now][i];
}
return ans;
}
int main(){
freopen("C.in","r",stdin);
freopen("C.out","w",stdout);
int n,q;
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
int u,v;
for(int i=1;i<=n-1;i++){
scanf("%d%d",&u,&v);
J(u,v);
J(v,u);
}
Dfs(1,0);
for(int i=1;i<=q;i++){
scanf("%d%d",&u,&v);
int lc=lca(u,v);
printf("%lld
",query1(u,t[lc][0])+query2(v,deep[u]+deep[v]-deep[lc]*2+1)-query2(lc,deep[u]-deep[lc]+1));
}
return 0;
}