题目大意:
有 (n) 个三重集合 { (pos) (,) (a) (,) (b) } , (pos) 是唯一的,它表示集合编号 ( (1) (≤) (pos) (≤) (n) )。
设 (a)(pos) 、(b)(pos) 分别表示第 (pos) 号集合中的元素 (a) 、(b)。
在这些集合中,给定 (n-1) 个偏序对 ( (i) (,) (j) ) ( 对于所有偏序对中,右项中 (j) 的值只会出现一次 )
表示一组大小关系 (a)(i) (>) (a)(j) ( (i) (≠) (j) (∈) (pos) ),且对于每个集合中的 (b) 都会有一个初始值。
定义 (O) ( (i) (,) (j) (,) (s) ) 表示(:)
在给定的 (n-1) 个大小关系下,集合编号 (t) 满足 ( (a)(t) (>) (a)(i) (,) (a)(t) (>)(a)(j) ),将所有满足的集合 (t) 中的 元素 (b)(t) 都加上 (s) 。
在执行 (k) 个 (O) ( (i) (,) (j) (,) (s) ) 之后,求出所有集合中的 (b) 之和。
保证给定的关系合法,不会出现 (a)(i) (>) (a)(j) 且 (a)(i) (<) (a)(j) 。
分析:
1、通过偏序关系,可以想到拓扑排序出一个大小关系的全序。
2、对于两个集合 (i) 、(j) ,若 (a)(i) (>) (a)(j) ,考虑连出一条有向边 (i) (→) (j) ,由于不会出现 (i) (→) (j) 且 (i) (←) (j) 的情况,故将这 (n-1) 个偏序都这样处理,那么它们连接起来一定不会有环,且必是一个能被拓扑排序的有向图。
3、由于偏序对中右项的 (j) 只会出现一次,那么相当于这样的有向图中的每一个点的入度至多为 1 ,由于只有 (n-1) 条有向边,这样连起来的图形必然是一棵完整的树。
4、开始考虑 (O) ( (i) (,) (j) (,) (s) ),对于集合编号 (t) 满足 ( (a)(t) (>) (a)(i) (,) (a)(t) (>)(a)(j) ) ,那么 (t) 必是 (i) 与 (j) 的共同父亲或祖先。这样就相当于求得 (LCA) (() (i) (,) (j) ()) ,然后 (lca) 及 (lca) 的父亲和祖先节点的 (b) 值 加上 (s) 即可,这里差分标记它或者记录 (lca) 的深度就可以(O(1)) 处理。
5、而对于 (LCA) (() (i) (,) (j) ()) 等于 (i) 或者 (j) 时要特殊考虑,因为 (lca) 上的 (a) 关系并不满足这两个关系 (a)(lca) (>) (a)(i) (,) (a)(lca) (>)(a)(j),对于这种情况,应该跳过 (lca) ,再向 (lca) 上一个点处理即可。
比如 (4) 和 (2) 的 (lca) 为 (2) ,只有 (1) 才满足 (a)(1) (>) (a)(2) 且 (a)(1) (>)(a)(4) 的关系,而 作为 (lca) 的 (2) 并不满足,故需要跳过它。
树上差分:
对 (lca) 或 (lca) 上一个点 的差分值 += (s) 即可,要注意的是,(pre[]) (tarjan 求 lca 中的并查集数组)并不是一直指向父亲节点的(还没有往上并),所以要开一个数组记录父亲节点。
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
typedef long long ll;
const double PI=acos(-1.0);
const ll mod = 1e9 + 7;
const int maxn = 100008;
const int inf = 0x3f3f3f3f;
const ll INF = 1e18 + 10;
const double eps = 1e-6;
using namespace std;
int n,k,cnt,tot,root;
int head[maxn],qhead[maxn],in[maxn],pre[maxn],fa[maxn];
ll ans,b[maxn],f[maxn];
bool vis[maxn],flag[maxn<<1];
struct Query{
int to;
ll val;
int next;
}q[maxn<<1];
struct Edge{
int to;
int next;
}edge[maxn];
int find(int x){
if(pre[x]==x) return x;
return pre[x]=find(pre[x]);
}
inline void add(int u,int v){
edge[++cnt].to=v;
edge[cnt].next=head[u];
head[u]=cnt;
return;
}
inline void qadd(int u,int v,ll w){
q[++tot].to=v;
q[tot].val=w;
q[tot].next=qhead[u];
qhead[u]=tot;
return;
}
void tarjan(int u){
vis[u]=true;
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;
fa[v]=u;
tarjan(v);
pre[v]=u;
}
for(int i=qhead[u];~i;i=q[i].next){
int v=q[i].to;
if(vis[v]&&!flag[i]){
flag[i]=flag[i^1]=true;
int lca=find(v);
if(lca!=u&&lca!=v) f[lca]+=q[i].val;
else f[fa[lca]]+=q[i].val;
}
}
return;
}
void dfs(int u){
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;
dfs(v);
f[u]+=f[v];
}
ans+=f[u];
return;
}
int main()
{
tot=cnt=-1;
memset(head,-1,sizeof(head));
memset(qhead,-1,sizeof(qhead));
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++) {
scanf("%lld",&b[i]);
ans+=b[i];
pre[i]=i;
}
int A,B;
ll C;
for(int i=1;i<=n-1;i++){
scanf("%d%d",&A,&B);
add(A,B),in[B]++;
}
for(int i=1;i<=n;i++) {
if(!in[i]) {
root=i;
break;
}
}
for(int i=1;i<=k;i++){
scanf("%d%d%lld",&A,&B,&C);
qadd(A,B,C),qadd(B,A,C);
}
tarjan(root);
dfs(root);
printf("%lld
", ans);
}
或者直接标记各个点的深度,即可以直接保存该点的父亲或祖先节点的个数来统计答案了。
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
typedef long long ll;
const double PI=acos(-1.0);
const ll mod = 1e9 + 7;
const int maxn = 100008;
const int inf = 0x3f3f3f3f;
const ll INF = 1e18 + 10;
const double eps = 1e-6;
using namespace std;
int n,k,cnt,tot,root;
int head[maxn],qhead[maxn],in[maxn],pre[maxn],fa[maxn],deep[maxn];
ll ans,b[maxn],f[maxn];
bool vis[maxn],flag[maxn<<1];
struct Query{
int to;
ll val;
int next;
}q[maxn<<1];
struct Edge{
int to;
int next;
}edge[maxn];
int find(int x){
if(pre[x]==x) return x;
return pre[x]=find(pre[x]);
}
inline void add(int u,int v){
edge[++cnt].to=v;
edge[cnt].next=head[u];
head[u]=cnt;
return;
}
inline void qadd(int u,int v,ll w){
q[++tot].to=v;
q[tot].val=w;
q[tot].next=qhead[u];
qhead[u]=tot;
return;
}
void tarjan(int u){
vis[u]=true;
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;
fa[v]=u;
deep[v]=deep[u]+1;
tarjan(v);
pre[v]=u;
}
for(int i=qhead[u];~i;i=q[i].next){
int v=q[i].to;
if(vis[v]&&!flag[i]){
flag[i]=flag[i^1]=true;
int lca=find(v);
if(lca!=u&&lca!=v) ans+=deep[lca]*q[i].val;
else ans+=deep[fa[lca]]*q[i].val;
}
}
return;
}
int main()
{
tot=cnt=-1;
memset(head,-1,sizeof(head));
memset(qhead,-1,sizeof(qhead));
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++) {
scanf("%lld",&b[i]);
ans+=b[i];
pre[i]=i;
}
int A,B;
ll C;
for(int i=1;i<=n-1;i++){
scanf("%d%d",&A,&B);
add(A,B),in[B]++;
}
for(int i=1;i<=n;i++) {
if(!in[i]) {
root=i;
break;
}
}
for(int i=1;i<=k;i++){
scanf("%d%d%lld",&A,&B,&C);
qadd(A,B,C),qadd(B,A,C);
}
deep[root]=1;
tarjan(root);
printf("%lld
", ans);
}