Problem Sometimes Naive
题目大意
给你一棵n个节点的树,有点权。
要求支持两种操作:
操作1:更改某个节点的权值。
操作2:给定u,v, 求 Σw[i][j] i , j 为任意两点且i到j的路径与u到v的路径相交。
解题分析
容易发现对于一个询问,答案为总点权和的平方 减去 去掉u--v这条链后各个子树的点权和的平方的和。
开两棵线段树,tag1记录点权和,tag2记录某点的所有轻链子树的点权和的平方的和。
每次沿着重链往上走时,直接加上这条重链的所有点的tag2和,若有重儿子则直接用tag1计算。由于该条重链必定为其父亲的轻链,故为防止计算重复,还需减去该重链所有点的tag1平方和。
最后爬到同一颗重链后,还需计算重链上方所有点的贡献。
参考程序
1 #include <map> 2 #include <set> 3 #include <stack> 4 #include <queue> 5 #include <cmath> 6 #include <ctime> 7 #include <string> 8 #include <vector> 9 #include <cstdio> 10 #include <cstdlib> 11 #include <cstring> 12 #include <cassert> 13 #include <iostream> 14 #include <algorithm> 15 #pragma comment(linker,"/STACK:102400000,102400000") 16 using namespace std; 17 18 #define V 100008 19 #define E 200008 20 #define LL long long 21 #define lson l,m,rt<<1 22 #define rson m+1,r,rt<<1|1 23 #define clr(x,v) memset(x,v,sizeof(x)); 24 #define rep(x,y,z) for (int x=y;x<=z;x++) 25 #define repd(x,y,z) for (int x=y;x>=z;x--) 26 const int mo = 1000000007; 27 const int inf = 0x3f3f3f3f; 28 const int INF = 2000000000; 29 /**************************************************************************/ 30 int n,m,tot; 31 int val[V]; 32 int size[V],fa[V],w[V],top[V],rk[V],dep[V],son[V]; 33 34 struct line{ 35 int u,v,nt; 36 line(int u=0,int v=0,int nt=0):u(u),v(v),nt(nt){} 37 }eg[E]; 38 int lt[V],sum; 39 40 void add(int u,int v){ 41 eg[++sum]=line(u,v,lt[u]); lt[u]=sum; 42 } 43 44 struct Segment_Tree{ 45 LL sum[V<<2]; 46 void clear(){ 47 clr(sum,0); 48 } 49 void pushup(int rt){ 50 sum[rt] = (sum[rt<<1] + sum[rt<<1|1]) % mo; 51 } 52 void update(int x,int val,int l,int r,int rt){ 53 if (l==r){ 54 sum[rt] += val; 55 sum[rt] = sum[rt] % mo; 56 return; 57 } 58 int m=(l+r)>>1; 59 if (x <= m) update(x,val,lson); 60 if (m < x) update(x,val,rson); 61 pushup(rt); 62 } 63 LL query(int L,int R,int l,int r,int rt){ 64 if (L<=l && r<=R){ 65 return sum[rt]; 66 } 67 int m=(l+r)>>1; 68 LL res=0; 69 if (L <= m) res += query(L,R,lson); 70 if (m < R) res += query(L,R,rson); 71 res = res % mo; 72 return res; 73 } 74 75 }Ts,Td; 76 void init(){ 77 clr(lt,0); sum=1; tot=0; 78 Ts.clear(); 79 Td.clear(); 80 } 81 void dfs_1(int u){ 82 dep[u]=dep[fa[u]]+1; size[u]=1; son[u]=0; 83 for (int i=lt[u];i;i=eg[i].nt){ 84 int v=eg[i].v; 85 if (v==fa[u]) continue; 86 fa[v]=u; 87 dfs_1(v); 88 if (size[v]>size[son[u]]) son[u]=v; 89 size[u]+=size[v]; 90 } 91 } 92 93 void dfs_2(int u,int tp){ 94 top[u]=tp; w[u]=++tot; rk[tot]=u;; 95 if (son[u]) dfs_2(son[u],tp); 96 for (int i=lt[u];i;i=eg[i].nt){ 97 int v=eg[i].v; 98 if (v==fa[u]||v==son[u]) continue; 99 dfs_2(v,v); 100 } 101 } 102 103 int sqr(int x){return 1ll*x*x %mo;} 104 void update(int x,int v){ 105 int u=top[x]; 106 while (fa[u]){ 107 LL sum=Ts.query(w[u],w[u]+size[u]-1,1,n,1); 108 Td.update(w[fa[u]],(sqr(val[x]-v)-sum*2*(val[x]-v) % mo)%mo,1,n,1); 109 u=top[fa[u]]; 110 } 111 Ts.update(w[x],v-val[x],1,n,1); 112 val[x]=v; 113 } 114 LL query(int x,int y){ 115 LL res=0; 116 while (top[x]!=top[y]){ 117 if (dep[top[x]]<dep[top[y]]) swap(x,y); 118 res += Td.query(w[top[x]],w[x],1,n,1); 119 res = res % mo; 120 121 if (son[x]){ 122 LL sum=Ts.query(w[son[x]],w[son[x]]+size[son[x]]-1,1,n,1); 123 res = res + sum*sum; 124 res = res % mo; 125 } 126 LL sum=Ts.query(w[top[x]],w[top[x]]+size[top[x]]-1,1,n,1); 127 128 res = res - sum*sum; 129 res = res % mo; 130 while (res<0) res+=mo; 131 x=fa[top[x]]; 132 } 133 if (dep[x]>dep[y]) swap(x,y); 134 res += Td.query(w[x],w[y],1,n,1); 135 res = res % mo; 136 if (son[y]){ 137 LL sum=Ts.query(w[son[y]],w[son[y]]+size[son[y]]-1,1,n,1); 138 res = res + sum*sum; 139 res = res % mo; 140 } 141 if (fa[x]){ 142 LL sum=Ts.query(1,n,1,n,1)-Ts.query(w[x],w[x]+size[x]-1,1,n,1); 143 res = res + sum*sum; 144 res = res % mo; 145 } 146 return res; 147 } 148 149 int main(){ 150 while (~scanf("%d %d",&n,&m)){ 151 init(); 152 rep(i,1,n) scanf("%d",&val[i]); 153 rep(i,2,n){ 154 int u,v; 155 scanf("%d %d",&u,&v); 156 add(u,v); add(v,u); 157 } 158 dfs_1(1); 159 dfs_2(1,1); 160 rep(i,1,n){ 161 int k=val[i]; 162 val[i]=0; 163 update(i,k); 164 } 165 while (m--){ 166 int x,u,v; 167 scanf("%d %d %d",&x,&u,&v); 168 if (x==1) update(u,v); 169 else { 170 LL sum=Ts.query(w[1],w[1]+size[1]-1,1,n,1); 171 sum = sum * sum; 172 sum = sum - query(u,v); 173 sum = sum % mo; 174 while (sum<0) sum+=mo; 175 printf("%lld ",sum); 176 } 177 } 178 } 179 }