稍微看了一下KD-tree的讲义,大概明白了它的原理,但是实现不出来。。。
所以无耻的抄了一下黄学长的。。。
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<algorithm> 5 #define ll long long 6 #define inf 1000000000 7 using namespace std; 8 inline int read() 9 { 10 int x=0,f=1;char ch=getchar(); 11 while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} 12 while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} 13 return x*f; 14 } 15 int n,m,root,D; 16 struct node 17 { 18 int d[2],mn[2],mx[2],l,r; 19 int& operator[](int x){return d[x];} 20 node(int x=0,int y=0) 21 { 22 l=0,r=0,d[0]=x,d[1]=y; 23 } 24 }p[500005]; 25 bool operator <(node a,node b) 26 { 27 return a[D]<b[D]; 28 } 29 inline int dis(node a,node b) 30 { 31 return abs(a[0]-b[0])+abs(a[1]-b[1]); 32 } 33 struct kdtree 34 { 35 int ans; 36 node t[1000005],T; 37 void update(int k) 38 { 39 node l=t[t[k].l],r=t[t[k].r]; 40 for(int i=0;i<2;i++) 41 { 42 if(t[k].l)t[k].mn[i]=min(t[k].mn[i],l.mn[i]),t[k].mx[i]=max(t[k].mx[i],l.mx[i]); 43 if(t[k].r)t[k].mn[i]=min(t[k].mn[i],r.mn[i]),t[k].mx[i]=max(t[k].mx[i],r.mx[i]); 44 } 45 } 46 int build(int l,int r,int now) 47 { 48 D=now; 49 int mid=(l+r)>>1; 50 nth_element(p+l,p+mid,p+r+1); 51 t[mid]=p[mid]; 52 for(int i=0;i<2;i++) 53 t[mid].mn[i]=t[mid].mx[i]=t[mid][i]; 54 if(l<mid)t[mid].l=build(l,mid-1,now^1); 55 if(r>mid)t[mid].r=build(mid+1,r,now^1); 56 update(mid); 57 return mid; 58 } 59 void insert(int k,int now) 60 { 61 if(T[now]>=t[k][now]) 62 { 63 if(t[k].r)insert(t[k].r,now^1); 64 else 65 { 66 t[k].r=++n;t[n]=T; 67 for(int i=0;i<2;i++) 68 t[n].mn[i]=t[n].mx[i]=t[n][i]; 69 } 70 } 71 else 72 { 73 if(t[k].l)insert(t[k].l,now^1); 74 else 75 { 76 t[k].l=++n;t[n]=T; 77 for(int i=0;i<2;i++) 78 t[n].mn[i]=t[n].mx[i]=t[n][i]; 79 } 80 } 81 update(k); 82 } 83 int get(int k,node p) 84 { 85 int tmp=0; 86 for(int i=0;i<2;i++) 87 tmp+=max(0,t[k].mn[i]-p[i]); 88 for(int i=0;i<2;i++) 89 tmp+=max(0,p[i]-t[k].mx[i]); 90 return tmp; 91 } 92 void query(int k,int now) 93 { 94 int d,dl=inf,dr=inf; 95 d=dis(t[k],T); 96 ans=min(ans,d); 97 if(t[k].l)dl=get(t[k].l,T); 98 if(t[k].r)dr=get(t[k].r,T); 99 if(dl<dr) 100 { 101 if(dl<ans)query(t[k].l,now^1); 102 if(dr<ans)query(t[k].r,now^1); 103 } 104 else 105 { 106 if(dr<ans)query(t[k].r,now^1); 107 if(dl<ans)query(t[k].l,now^1); 108 } 109 } 110 int query(node p) 111 { 112 ans=inf;T=p;query(root,0); 113 return ans; 114 } 115 void insert(node p) 116 { 117 T=p;insert(root,0); 118 } 119 }kd; 120 int main() 121 { 122 n=read();m=read(); 123 for(int i=1;i<=n;i++)p[i][0]=read(),p[i][1]=read(); 124 root=kd.build(1,n,0); 125 while(m--) 126 { 127 int opt=read(),x=read(),y=read(); 128 if(opt==1)kd.insert(node(x,y)); 129 else printf("%d ",kd.query(node(x,y))); 130 } 131 return 0; 132 }