转载自:http://blog.csdn.net/qq_18455665/article/details/50989113
前言
- 首先说说出处:
- 清华大学 张昆玮(zkw) - ppt 《统计的力量》
- 本文
(辣鸡)编辑:BeiYu - 写这篇博客的原因:
1.zkw线段树非递归,效率高,代码短
2.网上关于zkw线段树的讲解实在是太少了
3.个人感觉很实用
更新日志
- 20160327-Part 1(zkw线段树的建立)
- 20160329-Part 2(单点操作)
- 20160329-Part 3(区间操作)
Part 1
来说说它的构造
线段树的堆式储存
我们来转成二进制看看
小学生问题:找规律
规律是很显然的
- 一个节点的父节点是这个数左移1,这个位运算就是低位舍弃,所有数字左移一位
- 一个节点的子节点是这个数右移1,是左节点,右移1+1是右节点
- 同一层的节点是依次递增的,第n层有2^(n-1)个节点
- 最后一层有多少节点,值域就是多少(这个很重要)
有了这些规律就可以开始着手建树了
- 查询区间[1,n]
最后一层不是2的次幂怎么办?
开到2的次幂!后面的空间我不要了!就是这么任性!
Build函数就这么出来了!找到不小于n的2的次幂
直接输入叶节点的信息
int n,M,q;int d[N<<1]; inline void Build(int n){ for(M=1;M<n;M<<=1); for(int i=M+1;i<=M+n;i++) d[i]=in(); }
建完了?当然没有!父节点还都是空的呢!
维护父节点信息?
倒叙访问,每个节点访问的时候它的子节点已经处理过辣!
- 维护区间和?
for(int i=M-1;i;--i) d[i]=d[i<<1]+d[i<<1|1];
- 维护最大值?
for(int i=M-1;i;--i) d[i]=max(d[i<<1],d[i<<1|1]);
- 维护最小值?
for(int i=M-1;i;--i) d[i]=min(d[i<<1],d[i<<1|1]);
这样就构造出了一颗二叉树,也就是zkw线段树了!
如果你是压行选手的话(比如我),建树的代码只需要两行。
是不是特别Easy!
新技能Get√
Part 2
单点操作
- 单点修改
void Change(int x,int v){ d[M+x]+=v; }
只是这么简单?当然不是,跟线段树一样,我们要更新它的父节点!
void Change(int x,int v){ d[x=M+x]+=v; while(x) d[x>>=1]=d[x<<1]+d[x<<1|1]; }
没了?没了。
- 单点查询(差分思想,后面会用到)
把d维护的值修改一下,变成维护它与父节点的差值(为后面的RMQ问题做准备)
建树的过程就要修改一下咯!
void Build(int n){ for(M=1;M<=n+1;M<<=1);for(int i=M+1;i<=M+n;i++) d[i]=in(); for(int i=M-1;i;--i) d[i]=min(d[i<<1],d[i<<1|1]),d[i<<1]-=d[i],d[i<<1|1]-=d[i]; }
在当前情况下的查询
void Sum(int x,int res=0){ while(x) res+=d[x],x>>=1;return res; }
Part 3
区间操作
询问区间和,把[s,t]闭区间换成(s,t)开区间来计算
int Sum(int s,int t,int Ans=0){ for (s=s+M-1,t=t+M+1;s^t^1;s>>=1,t>>=1){ if(~s&1) Ans+=d[s^1]; if( t&1) Ans+=d[t^1]; }return Ans; }
- 为什么
~s&1
? -
为什么
t&1
?
变成开区间了以后,如果s是左儿子,那么它的兄弟节点一定在区间内,同理,如果t是右儿子,那么它的兄弟节点也一定在区间内! -
这样计算不会重复吗?
答案是会的!所以注意迭代的出口s^t^1
如果s,t就是兄弟节点,那么也就迭代完成了。
代码简单,即使背过也不难QuQ
- 区间最小值
void Sum(int s,int t,int L=0,int R=0){ for(s=s+M-1,t=t+M+1;s^t^1;s>>=1,t>>=1){ L+=d[s],R+=d[t]; if(~s&1) L=min(L,d[s^1]); if(t&1) R=min(R,d[t^1]); } int res=min(L,R);while(s) res+=d[s>>=1]; }
差分!
不要忘记最后的统计!
还有就是建树的时候是用的最大值还是最小值,这个一定要注意,影响到差分。
- 区间最大值
void Sum(int s,int t,int L=0,int R=0){ for(s=s+M-1,t=t+M+1;s^t^1;s>>=1,t>>=1){ L+=d[s],R+=d[t]; if(~s&1) L=max(L,d[s^1]); if(t&1) R=max(R,d[t^1]); } int res=max(L,R);while(s) res+=d[s>>=1]; }
同理。
- 区间加法
void Add(int s,int t,int v,int A=0){ for(s=s+M-1,t=t+M+1;s^t^1;s>>=1,t>>=1){ if(~s&1) d[s^1]+=v;if(t&1) d[t^1]+=v; A=min(d[s],d[s^1]);d[s]-=A,d[s^1]-=A,d[s>>1]+=A; A=min(d[t],d[t^1]);d[t]-=A,d[t^1]-=A,d[t>>1]+=A; } while(s) A=min(d[s],d[s^1]),d[s]-=A,d[s^1]-=A,d[s>>=1]+=A; }
同样是差分!差分就是厉害QuQ
zkw线段树小试牛刀(code来自hzwer.com)
#include<cstdio> #include<iostream> #define M 261244 using namespace std; int tr[524289]; void query(int s,int t) { int ans=0; for(s=s+M-1,t=t+M+1;s^t^1;s>>=1,t>>=1) { if(~s&1)ans+=tr[s^1]; if(t&1)ans+=tr[t^1]; } printf("%d ",ans); } void change(int x,int y) { for(tr[x+=M]+=y,x>>=1;x;x>>=1) tr[x]=tr[x<<1]+tr[x<<1|1]; } int main() { int n,m,f,x,y; scanf("%d",&n); for(int i=1;i<=n;i++){scanf("%d",&x);change(i,x);} scanf("%d",&m); for(int i=1;i<=m;i++) { scanf("%d%d%d",&f,&x,&y); if(f==1)change(x,y); else query(x,y); } return 0; }
poj3468(code来自网络)
#include <cstdio> #include <cstring> #include <cctype> #define N ((131072 << 1) + 10) //表示节点个数->不小于区间长度+2的最小2的正整数次幂*2+10 typedef long long LL; inline int getc() { static const int L = 1 << 15; static char buf[L] , *S = buf , *T = buf; if (S == T) { T = (S = buf) + fread(buf , 1 , L , stdin); if (S == T) return EOF; } return *S++; } inline int getint() { static char c; while(!isdigit(c = getc()) && c != '-'); bool sign = (c == '-'); int tmp = sign ? 0 : c - '0'; while(isdigit(c = getc())) tmp = (tmp << 1) + (tmp << 3) + c - '0'; return sign ? -tmp : tmp; } inline char getch() { char c; while((c = getc()) != 'Q' && c != 'C'); return c; } int M; //底层的节点数 int dl[N] , dr[N]; //节点的左右端点 LL sum[N]; //节点的区间和 LL add[N]; //节点的区间加上一个数的标记 #define l(x) (x<<1) //x的左儿子,利用堆的性质 #define r(x) ((x<<1)|1) //x的右儿子,利用堆的性质 void pushdown(int x) { //下传标记 if (add[x]&&x<M) {//如果是叶子节点,显然不用下传标记(别忘了) add[l(x)] += add[x]; sum[l(x)] += add[x] * (dr[l(x)] - dl[l(x)] + 1); add[r(x)] += add[x]; sum[r(x)] += add[x] * (dr[r(x)] - dl[r(x)] + 1); add[x] = 0; } } int stack[20] , top;//栈 void upd(int x) { //下传x至根节点路径上节点的标记(自上而下,用栈实现) top = 0; int tmp = x; for(; tmp ; tmp >>= 1) stack[++top] = tmp; while(top--) pushdown(stack[top]); } LL query(int tl , int tr) { //求和 LL res=0; int insl = 0, insr = 0; //两侧第一个有用节点 for(tl=tl+M-1,tr=tr+M+1;tl^tr^1;tl>>=1,tr>>=1) { if (~tl&1) { if (!insl) upd(insl=tl^1); res+=sum[tl^1]; } if (tr&1) { if(!insr) upd(insr=tl^1) res+=sum[tr^1]; } } return res; } void modify(int tl , int tr , int val) { //修改 int insl = 0, insr = 0; for(tl=tl+M-1,tr=tr+M+1;tl^tr^1;tl>>=1,tr>>=1) { if (~tl&1) { if (!insl) upd(insl=tl^1); add[tl^1]+=val; sum[tl^1]+=(LL)val*(dr[tl^1]-dl[tl^1]+1); } if (tr&1) { if (!insr) upd(insr=tr^1); add[tr^1]+=val; sum[tr^1]+=(LL)val*(dr[tr^1]-dl[tr^1]+1); } } for(insl=insl>>1;insl;insl>>=1) //一路update sum[insl]=sum[l(insl)]+sum[r(insl)]; for(insr=insr>>1;insr;insr>>=1) sum[insr]=sum[l(insr)]+sum[r(insr)]; } inline void swap(int &a , int &b) { int tmp = a; a = b; b = tmp; } int main() { //freopen("tt.in" , "r" , stdin); int n , ask; n = getint(); ask = getint(); int i; for(M = 1 ; M < (n + 2) ; M <<= 1); for(i = 1 ; i <= n ; ++i) sum[M + i] = getint() , dl[M + i] = dr[M + i] = i; //建树 for(i = M - 1; i >= 1 ; --i) { //预处理节点左右端点 sum[i] = sum[l(i)] + sum[r(i)]; dl[i] = dl[l(i)]; dr[i] = dr[r(i)]; } char s; int a , b , x; while(ask--) { s = getch(); if (s == 'Q') { a = getint(); b = getint(); if (a > b) swap(a , b); printf("%lld " , query(a , b)); } else { a = getint(); b = getint(); x = getint(); if (a > b) swap(a , b); modify(a , b , x); } } return 0; }
可持久化线段树版本?!(来自http://blog.csdn.net/forget311300/article/details/44306265)
#include <iostream> #include <cstdio> #include <cstring> #include <cmath> #include <algorithm> #include <vector> #define mp(x,y) make_pair(x,y) using namespace std; const int N = 100000; const int inf = 0x3f3f3f3f; int a[N + 10]; int b[N + 10]; int M; int lq, rq; vector<pair<int, int> > s[N * 22]; void add(int id, int cur) { cur += M; int lat = 0; if (s[cur].size()) lat = s[cur][s[cur].size() - 1].second; s[cur].push_back(mp(id, ++lat)); for (cur >>= 1; cur; cur >>= 1) { int l = 0; if (s[cur << 1].size()) l = s[cur << 1][s[cur << 1].size() - 1].second; int r = 0; if (s[cur << 1 | 1].size()) r = s[cur << 1 | 1][s[cur << 1 | 1].size() - 1].second; s[cur].push_back(mp(id, l + r)); } } int Q(int id, int k) { if (id >= M) return id - M; int l = id << 1, r = l ^ 1; int ll = lower_bound(s[l].begin(), s[l].end(), mp(lq, inf)) - s[l].begin() - 1; int rr = lower_bound(s[l].begin(), s[l].end(), mp(rq, inf)) - s[l].begin() - 1; int kk = 0; if (rr >= 0)kk = s[l][rr].second; if (ll >= 0)kk = s[l][rr].second - s[l][ll].second; if (kk < k)return Q(r, k - kk); return Q(l, k); } int main() { int n, m; while (~scanf("%d%d", &n, &m)) { for (int i = 0; i < n; i++) { scanf("%d", a + i); b[i] = a[i]; } sort(b, b + n); int nn = unique(b, b + n) - b; for (M = 1; M < nn; M <<= 1); for (int i = 1; i < M + M; i++) { s[i].clear(); //s[i].push_back(mp(0, 0)); } for (int i = 0; i < n; i++) { int id = lower_bound(b, b + nn, a[i]) - b; add(i + 1, id); } while (m--) { int k; scanf("%d %d %d", &lq, &rq, &k); lq--; int x = Q(1, k); printf("%d ", b[x]); } } return 0; }
完全模板?!(来自http://blog.csdn.net/forget311300/article/details/44306265)
const int N = 1e5; struct node { int sum, d, v; int l, r; void init() { d = 0; v = -1; } void cb(node ls, node rs) { sum = ls.sum + rs.sum; l = ls.l, r = rs.r; } int len() { return r - l + 1; } void V(int x) { sum = len() * x; d = 0; v = x; } void D(int x) { sum += len() * x; d += x; } }; struct tree { int m, h; node g[N << 2]; void init(int n) { for (m = h = 1; m < n + 2; m <<= 1, h++); int i = 0; for (; i <= m; i++) { g[i].init(); g[i].sum = 0; } for (; i <= m + n; i++) { g[i].init(); scanf("%d", &g[i].sum); g[i].l = g[i].r = i - m; } for (; i < m + m; i++) { g[i].init(); g[i].sum = 0; g[i].l = g[i].r = i - m; } for (i = m - 1; i > 0; i--) g[i].cb(g[i << 1], g[i << 1 | 1]); } void dn(int x) { for (int i = h - 1; i > 0; i--) { int f = x >> i; if (g[f].v != -1) { g[f << 1].V(g[f].v); g[f << 1 | 1].V(g[f].v); } if (g[f].d) { g[f << 1].D(g[f].d); g[f << 1 | 1].D(g[f].d); } g[f].v = -1; g[f].d = 0; } } void up(int x) { for (x >>= 1; x; x >>= 1) { if (g[x].v != -1)continue; int d = g[x].d; g[x].d = 0; g[x].cb(g[x << 1], g[x << 1 | 1]); g[x].D(d); } } void update(int l, int r, int x, int o) { l += m - 1, r += m + 1; dn(l), dn(r); for (int s = l, t = r; s ^ t ^ 1; s >>= 1, t >>= 1) { if (~s & 1) { if (o) g[s ^ 1].V(x); else g[s ^ 1].D(x); } if (t & 1) { if (o) g[t ^ 1].V(x); else g[t ^ 1].D(x); } } up(l), up(r); } int Q(int l, int r) { int ans = 0; l += m - 1, r += m + 1; dn(l), dn(r); for (int s = l, t = r; s ^ t ^ 1; s >>= 1, t >>= 1) { if (~s & 1)ans += g[s ^ 1].sum; if (t & 1)ans += g[t ^ 1].sum; } return ans; } };
二维情况(来自http://blog.csdn.net/forget311300/article/details/44306265)
#include <cstdio> #include <algorithm> #include <cstring> #include <cmath> #include <vector> #include <iostream> using namespace std; const int W = 1000; int m; struct tree { int d[W << 2]; void o() { for (int i = 1; i < m + m; i++)d[i] = 0; } void Xor(int l, int r) { l += m - 1, r += m + 1; for (int s = l, t = r; s ^ t ^ 1; s >>= 1, t >>= 1) { if (~s & 1)d[s ^ 1] ^= 1; if (t & 1)d[t ^ 1] ^= 1; } } } g[W << 2]; void chu() { for (int i = 1; i < m + m; i++) g[i].o(); } void Xor(int lx, int ly, int rx, int ry) { lx += m - 1, rx += m + 1; for (int s = lx, t = rx; s ^ t ^ 1; s >>= 1, t >>= 1) { if (~s & 1)g[s ^ 1].Xor(ly, ry); if (t & 1)g[t ^ 1].Xor(ly, ry); } } int Q(int x, int y) { int ans = 0; for (int xx = x + m; xx; xx >>= 1) { for (int yy = y + m; yy; yy >>= 1) { ans ^= g[xx].d[yy]; } } return ans; } int main() { int T; cin >> T; int fl = 0; while (T--) { if (fl) { printf(" "); } fl = 1; int N, M; cin >> N >> M; for (m = 1; m < N + 2; m <<= 1); chu(); while (M--) { char o[4]; scanf("%s", o); if (*o == 'Q') { int x, y; scanf("%d%d", &x, &y); printf("%d ", Q(x, y)); } else { int lx, ly, rx, ry; scanf("%d%d%d%d", &lx, &ly, &rx, &ry); Xor(lx, ly, rx, ry); } } } return 0; }
非递归扫描线+离散化?!(来自http://blog.csdn.net/forget311300/article/details/44306265)
#include <algorithm> #include <iostream> #include <cstdio> #include <cstring> #include <vector> #include <cmath> using namespace std; const int N = 111; int n; vector<double> y; struct node { double s; int c; int l, r; void chu(double ss, int cc, int ll, int rr) { s = ss; c = cc; l = ll, r = rr; } double len() { return y[r] - y[l - 1]; } } g[N << 4]; int M; void init(int n) { for (M = 1; M < n + 2; M <<= 1); g[M].chu(0, 0, 1, 1); for (int i = 1; i <= n; i++) g[i + M].chu(0, 0, i, i); for (int i = n + 1; i < M; i++) g[i + M].chu(0, 0, n, n); for (int i = M - 1; i > 0; i--) g[i].chu(0, 0, g[i << 1].l, g[i << 1 | 1].r); } struct line { double x, yl, yr; int d; line() {} line(double x, double yl, double yr, int dd): x(x), yl(yl), yr(yr), d(dd) {} bool operator < (const line &cc)const { return x < cc.x || (x == cc.x && d > cc.d); } }; vector<line>L; void one(int x) { if (x >= M) { g[x].s = g[x].c ? g[x].len() : 0; return; } g[x].s = g[x].c ? g[x].len() : g[x << 1].s + g[x << 1 | 1].s; } void up(int x) { for (; x; x >>= 1) one(x); } void add(int l, int r, int d) { if (l > r)return; l += M - 1, r += M + 1; for (int s = l, t = r; s ^ t ^ 1; s >>= 1, t >>= 1) { if (~s & 1) { g[s ^ 1].c += d; one(s ^ 1); } if (t & 1) { g[t ^ 1].c += d; one(t ^ 1); } } up(l); up(r); } double sol() { y.clear(); L.clear(); for (int i = 0; i < n; i++) { double lx, ly, rx, ry; scanf("%lf %lf %lf %lf", &lx, &ly, &rx, &ry); L.push_back(line(lx, ly, ry, 1)); L.push_back(line(rx, ly, ry, -1)); y.push_back(ly); y.push_back(ry); } sort(y.begin(), y.end()); y.erase(unique(y.begin(), y.end()), y.end()); init(y.size()); sort(L.begin(), L.end()); n = L.size() - 1; double ans = 0; for (int i = 0; i < n; i++) { int l = upper_bound(y.begin(), y.end(), L[i].yl + 1e-8) - y.begin(); int r = upper_bound(y.begin(), y.end(), L[i].yr + 1e-8) - y.begin() - 1; add(l, r, L[i].d); ans += g[1].s * (L[i + 1].x - L[i].x); } return ans; } int main() { int ca = 1; while (cin >> n && n) { printf("Test case #%d Total explored area: %.2f ", ca++, sol()); } return 0; }