测试地址:配对树
做法:本题需要用到线段树合并。
这题是NOI2018中间Mychael大佬给的一道神题,现在他已AFO,我也从蒟蒻变成了一个没那么弱的蒟蒻,于是在写这一题的同时,怀念一下和Mychael大佬一起度过的OI时光(虽然只是网上交流2333)。
首先,如果一个区间内点出现了多次,肯定是先自己结合掉次是最优的。那么剩下的一个点集,对于树上的一棵子树,有这样的结论:对于子树中的点,匹配点在子树之外的一定不超过个。因为如果有个点匹配点在子树之外,那么这两条路径肯定有重叠部分,这时把匹配关系替换一下,重叠部分消失,答案就更优了。于是进一步地,我们得到结论:一条边存在在最终的匹配路径中,当且仅当它较深的一个端点的子树中的点,在所选点集中的出现的次数是奇数。因为如果有偶数个点,按照上面的结论,它们在子树中自行匹配肯定是最优的,所以没必要用到这条边。
那么我们就可以转化求贡献的方式了。原本我们是对于每个区间,求匹配路径长度的总和。现在我们可以对于树上的每条边,求有多少区间使得这条边存在在匹配路径中,这就是这条边的贡献次数。于是也就是求,有多少长度为偶数的区间,使得子树中的点在区间中出现的次数是奇数。
现在假设我们求出了子树中的点在序列中出现的所有位置,很自然地想到先求一个模意义下的前缀和,然后在包含前缀长度为的前缀和(也就是一个)的基础上,满足条件的长度为偶数的区间数目,应该等于:将前缀和按前缀长度的奇偶分类,分别把同一类中,前缀和为的前缀数目和前缀和为的前缀数目乘起来(也就是求前缀和不同的前缀对数),再相加。这个应该比较好理解,就不详细说明了。
可是子树中点在序列中出现的所有位置这个信息,是随着子树越来越大而逐渐合并上去的……慢着,既然位置的信息可以线段树合并,那么上面要求的贡献次数可不可以合并呢?当然是可以的。
线段树作为序列分治形态的表示,要用线段树处理,我们就只需要想到怎么在分治中合并信息即可。因为要保证线段树合并的复杂度,所以我们显然不能直接存储前缀和的信息,只能用位置的信息推算出前缀和的信息。我们可以存储两个数组,表示一个区间内,长度在奇数或偶数位置的前缀和为的前缀数量。注意,这里的前缀就是指区间内的前缀。所以我们在合并时,需要考虑左半区间的长度的奇偶性,还要考虑左半区间的前缀和的奇偶性。于是我们就解决了合并的问题。
那么我们就可以用在树上线段树合并的方法来完成这一题了,时间复杂度为。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
int n,m,first[100010]={0},tot=0;
int totp=0,rt[100010]={0},ch[2000010][2]={0};
ll cnt[2000010][2]={0},ans[2000010]={0},totans;
bool sum[2000010]={0};
struct edge
{
int v,next;
ll w;
}e[200010];
void insert(int a,int b,ll w)
{
e[++tot].v=b;
e[tot].next=first[a];
e[tot].w=w;
first[a]=tot;
}
void pushup(int v,int l,int r)
{
int mid=(l+r)>>1,L=ch[v][0],R=ch[v][1];
ll cnt01=cnt[R][0],cnt11=cnt[R][1];
ll cnt00=((r-mid+1)>>1)-cnt01,cnt10=((r-mid)>>1)-cnt11;
if (sum[L]) swap(cnt00,cnt01),swap(cnt10,cnt11);
if ((mid-l+1)%2==1) swap(cnt00,cnt10),swap(cnt01,cnt11);
cnt[v][0]=(cnt[L][0]+cnt01)%mod;
cnt[v][1]=(cnt[L][1]+cnt11)%mod;
ans[v]=(ans[L]+ans[R])%mod;
ans[v]=(ans[v]+cnt[L][0]*cnt00+(ll)(((mid-l+2)>>1)-cnt[L][0])*cnt01)%mod;
ans[v]=(ans[v]+cnt[L][1]*cnt10+(ll)(((mid-l+1)>>1)-cnt[L][1])*cnt11)%mod;
sum[v]=sum[L]^sum[R];
}
void insert(int &v,int l,int r,int x)
{
if (!v) v=++totp;
if (l==r)
{
sum[v]=cnt[v][0]=1;
return;
}
int mid=(l+r)>>1;
if (x<=mid) insert(ch[v][0],l,mid,x);
else insert(ch[v][1],mid+1,r,x);
pushup(v,l,r);
}
int merge(int x,int y,int l,int r)
{
if (!x) return y;
if (!y) return x;
int mid=(l+r)>>1;
ch[x][0]=merge(ch[x][0],ch[y][0],l,mid);
ch[x][1]=merge(ch[x][1],ch[y][1],mid+1,r);
pushup(x,l,r);
return x;
}
void solve(int v,int fa)
{
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa)
{
solve(e[i].v,v);
totans=(totans+e[i].w*ans[rt[e[i].v]])%mod;
rt[v]=merge(rt[v],rt[e[i].v],0,m);
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int u,v;
ll w;
scanf("%d%d%lld",&u,&v,&w);
insert(u,v,w),insert(v,u,w);
}
for(int i=1;i<=m;i++)
{
int x;
scanf("%d",&x);
insert(rt[x],0,m,i);
}
totans=0;
solve(1,0);
printf("%lld",totans);
return 0;
}