题意:
给定一棵n个点的树,每条边有权值,树上两点路径长度定义为边权和。给定一个元素在[1,n]的长为m的序列,求出对于每个长为偶数的区间,区间中的数字两两匹配后每对点的路径长度之和最小值。输出所有长为偶数区间的这个最小值之和。
$n,mleq 10^5.$
题解:
转化很巧妙。
直接算很不好算,考虑计算每条边的贡献。一个性质是:假定在区间中的元素集合为S,对于某一条边分成的两个子树,如果两个子树中出现在S中的元素个数均为奇数,则这条边有1的贡献,否则没有贡献。
证明很简单,对于都为偶数的情况考虑反证,如果存在路径经过这条边,则一定至少两(偶数)条,那么可以把这两条路径都删去这条边得到更优解。对于都为奇数的情况,一定至少存在一条经过这条边的路径,去掉这条路径后则转化为了偶数的情况。证毕。
那么原题转化为:对于每个子树,如果将子树中的元素在序列中标记为1,那么要求的就是这个01串中有多少长为偶数的区间内1的个数为奇数。
暴力算是$mathcal{O}(nm)$的。我们考虑用线段树维护01串,记录区间内1的个数,区间内位置为奇/偶,前缀和mod2为奇/偶的下标数量。线段树合并即可。复杂度$mathcal{O}(nlog m)$。
code:
1 #include<bits/stdc++.h> 2 #define rep(i,x,y) for (int i=(x);i<=(y);i++) 3 #define ll long long 4 #define inf 1000000001 5 #define y1 y1___ 6 using namespace std; 7 ll read(){ 8 char ch=getchar();ll x=0;int op=1; 9 for (;!isdigit(ch);ch=getchar()) if (ch=='-') op=-1; 10 for (;isdigit(ch);ch=getchar()) x=(x<<1)+(x<<3)+ch-'0'; 11 return x*op; 12 } 13 #define N 100005 14 #define M 2000005 15 #define mod 998244353 16 int n,m,cnt,tot,ans,head[N],rt[N],ls[M],rs[M],sum[M],a[M][2][2]; 17 struct edge{int to,nxt,v;}e[N<<1]; 18 void adde(int x,int y,int z){ 19 e[++cnt].to=y;e[cnt].nxt=head[x];head[x]=cnt; 20 e[cnt].v=z; 21 } 22 void up(int k,int l,int r){ 23 sum[k]=0; 24 if (ls[k]) sum[k]+=sum[ls[k]]; 25 if (rs[k]) sum[k]+=sum[rs[k]]; 26 int x=ls[k]?sum[ls[k]]&1:0; 27 rep (i,0,1) rep (j,0,1){ 28 a[k][i][j]=0; 29 if (ls[k]) a[k][i][j]+=a[ls[k]][i][j]; 30 if (rs[k]) a[k][i][j]+=a[rs[k]][i^x][j]; 31 } 32 int mid=l+r>>1;//注意这两句别忘 33 if (!ls[k]) a[k][0][0]+=mid/2-(l-1)/2,a[k][0][1]+=(mid+1)/2-l/2; 34 if (!rs[k]) a[k][x][0]+=r/2-mid/2,a[k][x][1]+=(r+1)/2-(mid+1)/2; 35 } 36 void ins(int &k,int l,int r,int x){ 37 if (!k){//注意赋初始值 38 k=++tot; 39 a[k][0][0]=r/2-(l-1)/2; 40 a[k][0][1]=(r+1)/2-l/2; 41 } 42 if (l==r){sum[k]++;return;} 43 int mid=l+r>>1; 44 if (x<=mid) ins(ls[k],l,mid,x);else ins(rs[k],mid+1,r,x); 45 up(k,l,r); 46 } 47 int merge(int x,int y,int l,int r){ 48 if (!x||!y) return x|y; 49 int mid=l+r>>1; 50 ls[x]=merge(ls[x],ls[y],l,mid); 51 rs[x]=merge(rs[x],rs[y],mid+1,r); 52 up(x,l,r); 53 return x; 54 } 55 void upd(int &x,int y){x+=y;x-=x>=mod?mod:0;} 56 void dfs(int u,int pr){ 57 for (int i=head[u];i;i=e[i].nxt) if (e[i].to!=pr){ 58 int v=e[i].to; 59 dfs(v,u); 60 upd(ans,((ll)a[rt[v]][0][0]*a[rt[v]][1][0]%mod+(ll)a[rt[v]][0][1]*a[rt[v]][1][1]%mod)%mod*e[i].v%mod); 61 rt[u]=merge(rt[u],rt[v],1,m+1); 62 } 63 } 64 int main(){ 65 n=read(),m=read(); 66 rep (i,1,n-1){ 67 int x=read(),y=read(),z=read(); 68 adde(x,y,z);adde(y,x,z); 69 } 70 rep (i,1,m) ins(rt[read()],1,m+1,i); 71 dfs(1,0); 72 cout<<ans<<' '; 73 return 0; 74 }
易错:
注意初始情况不是0,需要处理。