题目大意:
给你一个$n(nleq300000)$个结点的以$1$为根的树,结点有黑白两种颜色,每个点初始权值为$0$。进行以下2种共$m(mleq300000)$次操作:
1.给定结点$u$,对于所有的黑点$v$,令${
m LCA}(u,v)$的权值加上$v$;
2.改变$u$的颜色。
思路:
考虑没有操作2的情况,我们可以记录所有结点被询问的次数和所有黑点的编号。最后统计答案时,只需要一个树形DP,求出每个子树内被询问的次数$cnt$和所有黑点的编号之和$sum$即可。$w[x]=sum cnt[y] imes(sum[x]-sum[y])$。
考虑加上操作2,本质上就是多了一个表示时间的维度,考虑使用线段树降维,树形DP时只需要线段树合并更新答案即可。时间复杂度$O(nlog n)$。
1 #include<list> 2 #include<cstdio> 3 #include<cctype> 4 typedef long long int64; 5 inline int getint() { 6 register char ch; 7 while(!isdigit(ch=getchar())); 8 register int x=ch^'0'; 9 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 10 return x; 11 } 12 const int N=300001; 13 bool s[N]; 14 int64 w[N]; 15 int m,last[N]; 16 std::list<int> e[N]; 17 inline void add_edge(const int &u,const int &v) { 18 e[u].push_back(v); 19 e[v].push_back(u); 20 } 21 class SegmentTree { 22 private: 23 struct Node { 24 int cnt; 25 int64 sum; 26 Node *left,*right; 27 }; 28 public: 29 Node *root[N]; 30 void modify(Node *&p,const int &b,const int &e,const int &l,const int &r,const int &x,const int &y) { 31 p=p?:new Node(); 32 p->cnt+=x; 33 if(b==l&&e==r) { 34 p->sum+=y; 35 return; 36 } 37 const int mid=(b+e)>>1; 38 if(l<=mid) modify(p->left,b,mid,l,std::min(mid,r),x,y); 39 if(r>mid) modify(p->right,mid+1,e,std::max(mid+1,l),r,x,y); 40 } 41 void merge(Node *&p,Node *const &q,const int &b,const int &e,const int &id) { 42 if(!p||!q) { 43 p=p?:q; 44 return; 45 } 46 w[id]+=p->cnt*q->sum+q->cnt*p->sum; 47 p->cnt+=q->cnt; 48 p->sum+=q->sum; 49 const int mid=(b+e)>>1; 50 merge(p->left,q->left,b,mid,id); 51 merge(p->right,q->right,mid+1,e,id); 52 delete q; 53 } 54 }; 55 SegmentTree t; 56 void dfs(const int &x,const int &par) { 57 for(std::list<int>::iterator i=e[x].begin();i!=e[x].end();i++) { 58 const int &y=*i; 59 if(y==par) continue; 60 dfs(y,x); 61 t.merge(t.root[x],t.root[y],1,m,x); 62 } 63 } 64 int main() { 65 const int n=getint();m=getint(); 66 for(register int i=1;i<=n;i++) { 67 last[i]=s[i]=getint(); 68 } 69 for(register int i=1;i<n;i++) { 70 add_edge(getint(),getint()); 71 } 72 for(register int i=1;i<=m;i++) { 73 const int opt=getint(),u=getint(); 74 if(opt==1) { 75 t.modify(t.root[u],1,m,i,i,1,0); 76 if(s[u]) w[u]+=u; 77 } 78 if(opt==2) { 79 if(s[u]^=1) { 80 last[u]=i; 81 } else { 82 if(i!=1) t.modify(t.root[u],1,m,last[u],i-1,0,u); 83 } 84 } 85 } 86 for(register int i=1;i<=n;i++) { 87 if(s[i]) t.modify(t.root[i],1,m,last[i],m,0,i); 88 } 89 dfs(1,0); 90 for(register int i=1;i<=n;i++) printf("%lld ",w[i]); 91 return 0; 92 }