长春赛的 I 题是主席树,现在稍微的学了一点主席树,也就算入了个门吧= =
简单的来说主席树就是每个节点上面都是一棵线段树,但是这么多线段树会MLE吧?其实我们解决的办法就是有重复的节点给他利用起来,具体见幻神博客。
不妨以1~n上的求任意区间第k小的问题,就是上面博客中所写,我们从1访问到n的预处理中,每一个时间都新建一个线段树,这棵树上记录着已经出现的各个数字,这样我们求[L,R]上的第k小,我们拿R时刻的线段树减去(L-1)时刻的线段树,就是这个区间内需要的线段树,这个线段树上存在的数字其实就是[L,R]上存在的数字,我们在这里寻找我们需要的第k小就可以了。具体实现方法见上面的博客。
我自己的模板如下:
1 #include <stdio.h> 2 #include <algorithm> 3 #include <string.h> 4 #define t_mid (l+r>>1) 5 using namespace std; 6 const int N = 100000 + 5; 7 8 int n,q,tot,sz; 9 int a[N],b[N]; 10 int rt[N*20],sum[N*20],ls[N*20],rs[N*20]; 11 void build(int &o,int l,int r) 12 { 13 o = ++tot; 14 sum[o] = 0; 15 if(l==r) return; 16 build(ls[o],l,t_mid); 17 build(rs[o],t_mid+1,r); 18 } 19 20 void update(int &o,int l,int r,int last,int p) 21 { 22 o = ++tot; 23 ls[o] = ls[last]; 24 rs[o] = rs[last]; 25 sum[o] = sum[last] + 1; 26 if(l==r) return; 27 if(p <= t_mid) update(ls[o],l,t_mid,ls[last],p); 28 else update(rs[o],t_mid+1,r,rs[last],p); 29 } 30 31 int query(int ql,int qr,int l,int r,int k) 32 { 33 if(l==r) return l; 34 int cnt = sum[ls[qr]] - sum[ls[ql]]; 35 if(cnt >= k) return query(ls[ql],ls[qr],l,t_mid,k); 36 else return query(rs[ql],rs[qr],t_mid+1,r,k-cnt); 37 } 38 39 void work() 40 { 41 int ql,qr,k; 42 scanf("%d%d%d",&ql,&qr,&k); 43 int ans = query(rt[ql-1],rt[qr],1,sz,k); 44 printf("%d ",b[ans]); 45 } 46 47 int main() 48 { 49 while(scanf("%d%d",&n,&q)==2) 50 { 51 tot = 0; 52 for(int i=1;i<=n;i++) scanf("%d",a+i),b[i]=a[i]; 53 sort(b+1,b+1+n); 54 sz = unique(b+1,b+1+n) - (b+1); 55 build(rt[0],1,sz); 56 57 for(int i=1;i<=n;i++) 58 { 59 int t = lower_bound(b+1,b+1+sz,a[i]) - b; 60 update(rt[i],1,sz,rt[i-1],t); 61 } 62 while(q--) work(); 63 } 64 }
然后如果是在一棵树上,求其一条链上的区间第k小呢?其实也差不多,我们就想着怎么把这棵需要的线段树抽取出来就行。这棵树实际上就是 u - lca(u,v) + v - father(lca(u,v))。具体的画画图就可以懂了。这里还涉及到求LCA的方法,具体方法见《挑战程序设计》中的倍增法求LCA即可。
我自己的模板如下:
1 #include <stdio.h> 2 #include <algorithm> 3 #include <string.h> 4 #include <vector> 5 #include <math.h> 6 #define t_mid (l+r>>1) 7 using namespace std; 8 const int N = 100000 + 5; 9 const int MAX_LOG_N = 16 + 5; 10 11 int n,q,tot,sz; 12 int a[N],b[N]; 13 int rt[N*20],sum[N*20],ls[N*20],rs[N*20]; 14 int parent[MAX_LOG_N][N],depth[N]; 15 vector<int> G[N]; 16 17 void getDepth(int v,int p,int d) 18 { 19 parent[0][v] = p; 20 depth[v] = d; 21 for(int i=0;i<G[v].size();i++) 22 { 23 if(G[v][i] != p) getDepth(G[v][i],v,d+1); 24 } 25 } 26 27 void init() 28 { 29 getDepth(1,-1,0); 30 for(int k=0;k+1<MAX_LOG_N;k++) 31 { 32 for(int v=1;v<=n;v++) 33 { 34 if(parent[k][v] < 0) parent[k+1][v] = -1; 35 else parent[k+1][v] = parent[k][parent[k][v]]; 36 } 37 } 38 } 39 40 int lca(int u,int v) 41 { 42 if(depth[u]>depth[v]) swap(u,v); 43 for(int k=0;k<MAX_LOG_N;k++) 44 { 45 if((depth[v]-depth[u]) >> k & 1) 46 { 47 v = parent[k][v]; 48 } 49 } 50 if(u==v) return u; 51 for(int k=MAX_LOG_N-1;k>=0;k--) 52 { 53 if(parent[k][u] != parent[k][v]) 54 { 55 u = parent[k][u]; 56 v = parent[k][v]; 57 } 58 } 59 return parent[0][u]; 60 } 61 62 void build(int &o,int l,int r) 63 { 64 o = ++tot; 65 sum[o] = 0; 66 if(l==r) return; 67 build(ls[o],l,t_mid); 68 build(rs[o],t_mid+1,r); 69 } 70 71 void update(int &o,int l,int r,int last,int p) 72 { 73 o = ++tot; 74 ls[o] = ls[last]; 75 rs[o] = rs[last]; 76 sum[o] = sum[last] + 1; 77 if(l==r) return; 78 if(p <= t_mid) update(ls[o],l,t_mid,ls[last],p); 79 else update(rs[o],t_mid+1,r,rs[last],p); 80 } 81 82 int query(int u,int v,int x,int y,int l,int r,int k) 83 { 84 if(l==r) return l; 85 int cnt = sum[ls[u]] + sum[ls[v]] - sum[ls[x]] - sum[ls[y]]; 86 if(cnt >= k) return query(ls[u],ls[v],ls[x],ls[y],l,t_mid,k); 87 else return query(rs[u],rs[v],rs[x],rs[y],t_mid+1,r,k-cnt); 88 } 89 90 void work() 91 { 92 int u,v,k; 93 scanf("%d%d%d",&u,&v,&k); 94 int _lca = lca(u,v); 95 int _lca_fa = parent[0][_lca]; 96 int ans = query(rt[u],rt[v],rt[_lca],rt[_lca_fa],1,sz,k); 97 printf("%d ",b[ans]); 98 } 99 100 void dfs(int u,int fa) 101 { 102 for(int i=0;i<G[u].size();i++) 103 { 104 int v = G[u][i]; 105 if(v==fa) continue; 106 int t = lower_bound(b+1,b+1+sz,a[v]) - b; 107 update(rt[v],1,sz,rt[u],t); 108 dfs(v,u); 109 } 110 } 111 112 int main() 113 { 114 while(scanf("%d%d",&n,&q)==2) 115 { 116 tot = 0; 117 for(int i=1;i<=n;i++) G[i].clear(); 118 for(int i=1;i<=n;i++) scanf("%d",a+i),b[i]=a[i]; 119 sort(b+1,b+1+n); 120 sz = unique(b+1,b+1+n) - (b+1); 121 for(int i=1;i<n;i++) 122 { 123 int u,v;scanf("%d%d",&u,&v); 124 G[u].push_back(v); 125 G[v].push_back(u); 126 } 127 build(rt[0],1,sz); 128 init(); 129 130 int t = lower_bound(b+1,b+1+sz,a[1]) - b; 131 update(rt[1],1,sz,rt[0],t); 132 dfs(1,-1); 133 134 while(q--) work(); 135 } 136 }
好,接下来就是解决那个烦人的 I 题了。
我们首先需要用主席树来解决区间内不同的数的个数,这东西比较奥义- -直接上模板好了。。反正随便百度一下"主席树求区间内不同数的个数"都会出来spoj的D-query那题,随便看下原理就行= =。。。然后用二分解决 I 题(固定左端点,二分右端点,具体见代码。。)。
看我直接丢 I 题的代码~:
1 #include <stdio.h> 2 #include <algorithm> 3 #include <string.h> 4 #include <map> 5 #define t_mid (l+r>>1) 6 using namespace std; 7 const int N = 2*100000 + 50; 8 9 int rt[N*20*2],sum[N*20*2],ls[N*20*2],rs[N*20*2]; 10 int a[N],n,m,tot; 11 void build(int &o,int l,int r) 12 { 13 o = ++tot; 14 sum[o] = 0; 15 if(l == r) return; 16 build(ls[o],l,t_mid); 17 build(rs[o],t_mid+1,r); 18 } 19 20 void update(int &o,int l,int r,int last,int pos,int dt) 21 { 22 o = ++tot; 23 sum[o] = sum[last]; 24 ls[o] = ls[last]; 25 rs[o] = rs[last]; 26 if(l==r) {sum[o]+=dt;return;} 27 if(pos <= t_mid) update(ls[o],l,t_mid,ls[last],pos,dt); 28 else update(rs[o],t_mid+1,r,rs[last],pos,dt); 29 sum[o] = sum[ls[o]] + sum[rs[o]]; 30 } 31 32 int query(int l,int r,int o,int pos) 33 { 34 if(l == r) return sum[o]; 35 if(pos <= t_mid) return sum[rs[o]] + query(l,t_mid,ls[o],pos); 36 else return query(t_mid+1,r,rs[o],pos); 37 } 38 39 /* 40 int query(int l,int r,int L,int R,int x){ 41 if(L <= l && r <= R) return sum[x]; 42 int mid = (l+r) >> 1 , ret = 0; 43 if(L <= mid) ret += query(l,mid,L,R,ls[x]); 44 if(R > mid) ret += query(mid+1,r,L,R,rs[x]); 45 return ret; 46 } 47 */ 48 49 int main() 50 { 51 int T;scanf("%d",&T); 52 for(int kase=1;kase<=T;kase++) 53 { 54 scanf("%d%d",&n,&m); 55 int pre = 0; 56 map<int,int> mp; 57 tot = 0; 58 for(int i=1;i<=n;i++) scanf("%d",a+i); 59 build(rt[0],1,n); 60 61 for(int i=1;i<=n;i++) 62 { 63 if(mp.find(a[i]) == mp.end()) 64 { 65 mp[a[i]] = i; 66 update(rt[i],1,n,rt[i-1],i,1); 67 } 68 else 69 { 70 int temp = 0; 71 update(temp,1,n,rt[i-1],mp[a[i]],-1); 72 update(rt[i],1,n,temp,i,1); 73 } 74 mp[a[i]] = i; 75 } 76 //scanf("%d",&m); 77 printf("Case #%d:",kase); 78 while(m--) 79 { 80 int ql,qr;scanf("%d%d",&ql,&qr); 81 int L = min((ql+pre)%n+1,(qr+pre)%n+1); 82 int R = max((ql+pre)%n+1,(qr+pre)%n+1); 83 //L = ql, R = qr; 84 int k = (query(1,n,rt[R],L)+1)>>1; 85 int l = L, r = R; 86 //printf("!! %d %d ",L,R); 87 int ans = -1; 88 while(l<=r) 89 { 90 int mid = l + r >> 1; 91 int t = query(1,n,rt[mid],L); 92 //printf("mid is %d %d ",mid,t); 93 if(t < k) l = mid + 1; 94 else 95 { 96 r = mid - 1; 97 ans = mid; 98 } 99 } 100 /*while(l < r) 101 { 102 int mid = l + r >> 1; 103 int t = query(1,n,rt[mid],L); 104 if(t < k) l = mid + 1; 105 else r = mid; 106 }*/ 107 108 printf(" %d",ans); 109 pre = ans; 110 } 111 puts(""); 112 } 113 } 114 115 /* 116 100 117 20 100 118 1 2 3 4 3 2 1 2 4 2 2 3 1 2 3 1 4 4 2 1 119 1 20 120 1 10 121 2 5 122 4 6 123 3 2 124 4 7 125 126 100 127 5 100 128 0 1 0 2 3 129 1 5 130 */ 131 /* 132 #include<iostream> 133 //#include<bits/stdc++.h> 134 #include<cstdio> 135 #include<string> 136 #include<cstring> 137 #include<map> 138 #include<queue> 139 #include<set> 140 #include<stack> 141 #include<ctime> 142 #include<algorithm> 143 #include<cmath> 144 #include<vector> 145 #define showtime fprintf(stderr,"time = %.15f ",clock() / (double)CLOCKS_PER_SEC) 146 //#pragma comment(linker, "/STACK:1024000000,1024000000") 147 using namespace std; 148 typedef long long ll; 149 typedef long long LL; 150 #define MP make_pair 151 #define PII pair<int,int> 152 #define PLI pair<long long ,int> 153 #define PFI pair<double,int> 154 #define PLL pair<ll,ll> 155 #define PB push_back 156 #define F first 157 #define S second 158 #define lson l,mid,rt<<1 159 #define rson mid+1,r,rt<<1|1 160 #define debug cout<<"?????"<<endl; 161 //freopen("1005.in","r",stdin); 162 //freopen("data.out","w",stdout); 163 const int INF = 0x3f3f3f3f; 164 const double eps = 1e-2; 165 const int N = 4e5 + 50 ; 166 const double PI = acos(-1.); 167 const double E = 2.71828182845904523536; 168 const int MOD = 1e9+7; 169 typedef vector<ll> Vec; 170 typedef vector<Vec> Mat; 171 int n,m; 172 struct node{int l,r,sum;}T[N*40]; 173 int a[N],root[N],pre[N],tot; 174 int q,x,y; 175 int ans[N]; 176 vector<int> v; 177 int getid(int x){ return lower_bound(v.begin(),v.end(),x) - v.begin() + 1;} 178 void init(){ 179 tot = 0; 180 memset(root,0,sizeof(root)); 181 memset(pre,-1,sizeof(pre)); 182 v.clear(); 183 } 184 void update(int l,int r,int val,int &x,int y,int pos){ 185 T[++tot] = T[y] , T[tot].sum += val , x = tot; 186 if(l == r) return ; 187 int mid = (l + r) >> 1; 188 if(pos <= mid) update(l,mid,val,T[x].l,T[y].l,pos); 189 else update(mid+1,r,val,T[x].r,T[y].r,pos); 190 } 191 ** 192 * 【x=L,y=R】 不同数字的有多少个 193 * query(1,n,x,y,root[y]); 第y颗树。 194 * 195 int query(int l,int r,int L,int R,int x){ 196 if(L <= l && r <= R) return T[x].sum; 197 int mid = (l+r) >> 1 , ret = 0; 198 if(L <= mid) ret += query(l,mid,L,R,T[x].l); 199 if(R > mid) ret += query(mid+1,r,L,R,T[x].r); 200 return ret; 201 } 202 int main(){ 203 int kase = 1,T; 204 cin >> T; 205 while(T --){ 206 cin >> n >> m; 207 init(); 208 for(int i = 1 ; i <= n ; i ++) scanf("%d",&a[i]) , v.push_back(a[i]); 209 sort(v.begin(),v.end()); 210 v.erase(unique(v.begin(),v.end()),v.end()); 211 for(int i = 1 ; i <= n ; i ++){ 212 int id = getid(a[i]); 213 if(pre[id] == -1){ 214 update(1,n,1,root[i],root[i-1],i); 215 pre[id] = i; 216 }else{ 217 int tmp; 218 update(1,n,-1,tmp,root[i-1],pre[id]); 219 update(1,n,1,root[i],tmp,i); 220 pre[id] = i; 221 } 222 } 223 ans[0] = 0; 224 printf("Case #%d:",kase ++); 225 for(int i = 1 ; i <= m ; i ++){ 226 scanf("%d%d",&x,&y); 227 int l,r; 228 l = min((x+ans[i-1])%n+1,(y+ans[i-1])%n+1); 229 r = max((x+ans[i-1])%n+1,(y+ans[i-1])%n+1); 230 //l = x ; r = y; 231 //printf("%d %d !! ",l,r); 232 int k = (query(1,n,l,r,root[r])+1) / 2; 233 int ll = l , rr = r; 234 while(ll < rr){ 235 int mid = (ll + rr) / 2; 236 int t = query(1,n,l,mid,root[mid]); 237 if(t < k) ll = mid+1; 238 else rr = mid; 239 } 240 printf(" %d",rr); 241 ans[i] = rr; 242 } 243 puts(""); 244 } 245 return 0; 246 } 247 */
有几点想说明的:1.下面注释的是大力的代码,但是超时了,因为他的query方法和我的有点小差别,虽然都能实现需要的功能,但是似乎我的query方法复杂度更小一点(??)。。不过我的也是卡过的,但是我觉得在长春现场赛的话应该能过,感觉HDU的评测机这次有点坑。。2.我的代码本来是WA的,因为数组开小了,我上面的两个代码都是*20的,都没问题,这里必须要开*40的才行,被坑了这一次以后我下次都开大一点的好了,反正*40内存也够用= =。。那么主席树就写到这里好了,以后刷了题目有什么要补充的再补充好了~(话说我的数据结构真的好烂啊,,以后搞splay怎么办啊233。。)