题目地址:http://coursera.cs.princeton.edu/algs4/assignments/kdtree.html
分析:
Brute-force implementation. 蛮力实现的方法比较简单,就是逐个遍历每个point进行比较,实现下述API就可以了,没有什么难度。
1 import java.util.ArrayList; 2 import java.util.TreeSet; 3 import edu.princeton.cs.algs4.Point2D; 4 import edu.princeton.cs.algs4.RectHV; 5 import edu.princeton.cs.algs4.StdDraw; 6 /** 7 * @author evasean www.cnblogs.com/evasean/ 8 */ 9 public class PointSET { 10 private TreeSet<Point2D> points; 11 public PointSET() { 12 // construct an empty set of points 13 points = new TreeSet<Point2D>(); 14 } 15 16 public boolean isEmpty() { 17 // is the set empty? 18 return points.isEmpty(); 19 } 20 21 public int size() { 22 // number of points in the set 23 return points.size(); 24 } 25 26 public void insert(Point2D p) { 27 // add the point to the set (if it is not already in the set) 28 if(p==null) 29 throw new IllegalArgumentException("Point2D p is not illegal!"); 30 if(!points.contains(p)) 31 points.add(p); 32 } 33 34 public boolean contains(Point2D p) { 35 // does the set contain point p? 36 if(p==null) 37 throw new IllegalArgumentException("Point2D p is not illegal!"); 38 return points.contains(p); 39 } 40 41 public void draw() { 42 // draw all points to standard draw 43 for (Point2D p : points) { 44 p.draw(); 45 } 46 StdDraw.show(); 47 } 48 49 public Iterable<Point2D> range(RectHV rect) { 50 // all points that are inside the rectangle (or on the boundary) 51 if(rect==null) 52 throw new IllegalArgumentException("RectHV rect is not illegal!"); 53 ArrayList<Point2D> list = new ArrayList<Point2D>(); 54 for(Point2D point : points){ 55 if(rect.contains(point)) list.add(point); 56 } 57 return list; 58 } 59 60 public Point2D nearest(Point2D p) { 61 // a nearest neighbor in the set to point p; null if the set is empty 62 if(p==null) 63 throw new IllegalArgumentException("Point2D p is not illegal!"); 64 if(points.size() == 0) return null; 65 double neareatDistance = Double.POSITIVE_INFINITY; 66 Point2D nearest = null; 67 for(Point2D point : points){ 68 double tmp = p.distanceTo(point); 69 if(Double.compare(neareatDistance, tmp) == 1){ 70 neareatDistance = tmp; 71 nearest = 
point; 72 } 73 74 } 75 return nearest; 76 } 77 78 public static void main(String[] args) { 79 // unit testing of the methods (optional) 80 } 81 }
2d-tree implementation.
kd-tree插入时要交替以x坐标和y坐标作为判断依据。比如root节点处以x坐标为比较依据,那么当要查找或插入一个新节点point时,就比较root节点的x坐标和point的x坐标:如果后者比前者小,下一次要比较的就是root->left,相反下一次要比较的就是root->right(如果两者的x坐标相等,就再比较y坐标来决定方向;x、y坐标都相等说明是重复的点,不再插入)。进入下一层级之后,就改用y坐标作为比较依据,如此逐层交替。示例如下图:
区域搜索:查找落在给定矩形区域范围内(包括边界上)的所有points。从root开始递归查找,如果给定的矩形与当前节点对应的矩形不相交,那么就没有必要继续查找该节点及其子树了。
最近节点搜索:查找与给定point距离最近的节点。从root开始递归查找其左右子树,如果已经查找到的最近节点与给定point的距离,不大于该point到当前遍历节点对应矩形的距离,那么该节点及其子树中不可能有更近的点,就没必要遍历这个当前节点及其子树了。递归时优先进入给定point所在一侧的子树,往往能更早找到较近的点,从而剪掉更多的分支。
1 import java.util.ArrayList; 2 import edu.princeton.cs.algs4.Point2D; 3 import edu.princeton.cs.algs4.RectHV; 4 import edu.princeton.cs.algs4.StdDraw; 5 /** 6 * @author evasean www.cnblogs.com/evasean/ 7 */ 8 public class KdTree { 9 private Node root; 10 private class Node { 11 private Point2D p; 12 /* 13 * 节点的value就是包含该节点的矩形空间 其左右子树的矩形空间,就是通过该节点进行水平切分或垂直切分的两个子空间 14 */ 15 private RectHV rect; 16 private Node left, right; 17 private int size; 18 private boolean xCoordinate; // 标识该节点是否是以x坐标垂直切分 19 20 public Node(Point2D point, RectHV rect, int size, boolean xCoordinate) { 21 this.p = point; 22 this.rect = rect; 23 this.size = size; 24 this.xCoordinate = xCoordinate; 25 } 26 } 27 28 public KdTree() { 29 // construct an empty set of points 30 } 31 32 public boolean isEmpty() { 33 // is the set empty? 34 return size() == 0; 35 } 36 37 public int size() { 38 // number of points in the set 39 return size(root); 40 } 41 42 private int size(Node x) { 43 if (x == null) 44 return 0; 45 else 46 return x.size; 47 } 48 49 public void insert(Point2D p) { 50 // add the point to the set (if it is not already in the set) 51 if (p == null) 52 throw new IllegalArgumentException("Point2D p is not illegal!"); 53 if (root == null) 54 root = new Node(p, new RectHV(0.0, 0.0, 1.0, 1.0), 1, true); 55 else 56 insert(root, p); 57 // System.out.println("size="+root.size); 58 } 59 60 private void insert(Node x, Point2D p) { 61 if (x.xCoordinate == true) { // x的切分标志是x坐标 62 int cmp = Double.compare(p.x(), x.p.x()); 63 if (cmp == -1) { 64 if (x.left != null) 65 insert(x.left, p); 66 else { 67 RectHV parent = x.rect; 68 // 将节点x的矩形空间进行垂直切分后的左侧部分 69 double newXmin = parent.xmin(); 70 double newYmin = parent.ymin(); 71 double newXmax = x.p.x(); 72 double newYmax = parent.ymax(); 73 x.left = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, false); 74 } 75 } else if (cmp == 1) { 76 if (x.right != null) 77 insert(x.right, p); 78 else { 79 RectHV parent = x.rect; 80 // 将节点x的矩形空间进行垂直切分后的右侧部分 
81 double newXmin = x.p.x(); 82 double newYmin = parent.ymin(); 83 double newXmax = parent.xmax(); 84 double newYmax = parent.ymax(); 85 x.right = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, false); 86 } 87 } else { // x.key.x() 与 p.x() 相等 88 int cmp2 = Double.compare(p.y(), x.p.y()); 89 if (cmp2 == -1) { 90 if (x.left != null) 91 insert(x.left, p); 92 else { 93 x.left = new Node(p, x.rect, 1, false); 94 } 95 } else if (cmp2 == 1) { 96 if (x.right != null) 97 insert(x.right, p); 98 else { 99 x.right = new Node(p, x.rect, 1, false); 100 } 101 } 102 } 103 } else { // x的切分标志是y坐标 104 int cmp = Double.compare(p.y(), x.p.y()); 105 if (cmp == -1) { 106 if (x.left != null) 107 insert(x.left, p); 108 else { 109 RectHV parent = x.rect; 110 // 将节点x的矩形空间进行垂直切分后的左侧部分 111 double newXmin = parent.xmin(); 112 double newYmin = parent.ymin(); 113 double newXmax = parent.xmax(); 114 double newYmax = x.p.y(); 115 x.left = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, true); 116 } 117 } else if (cmp == 1) { 118 if (x.right != null) 119 insert(x.right, p); 120 else { 121 RectHV parent = x.rect; 122 // 将节点x的矩形空间进行垂直切分后的左侧部分 123 double newXmin = parent.xmin(); 124 double newYmin = x.p.y(); 125 double newXmax = parent.xmax(); 126 double newYmax = parent.ymax(); 127 x.right = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, true); 128 } 129 } else { // x.key.y() 与 p.y()相等 130 int cmp2 = Double.compare(p.x(), x.p.x()); 131 if (cmp2 == -1) { 132 if (x.left != null) 133 insert(x.left, p); 134 else { 135 x.left = new Node(p, x.rect, 1, true); 136 } 137 } else if (cmp2 == 1) { 138 if (x.right != null) 139 insert(x.right, p); 140 else { 141 x.right = new Node(p, x.rect, 1, true); 142 } 143 } 144 } 145 } 146 x.size = 1 + size(x.left) + size(x.right); 147 } 148 149 public boolean contains(Point2D p) { 150 // does the set contain point p? 
151 if (p == null) 152 throw new IllegalArgumentException("Point2D p is not illegal!"); 153 return contains(root, p); 154 } 155 156 private boolean contains(Node x, Point2D p) { 157 if(x == null ) return false; 158 if (x.p.equals(p)) 159 return true; 160 else { 161 if(x.xCoordinate == true){ 162 int cmp = Double.compare(p.x(), x.p.x()); 163 if(cmp == -1) return contains(x.left,p); 164 else if(cmp == 1 ) return contains(x.right,p); 165 else{ 166 int cmp2 = Double.compare(p.y(), x.p.y()); 167 if(cmp2 == -1) return contains(x.left,p); 168 else if(cmp2 == 1 ) return contains(x.right,p); 169 else return true; 170 } 171 }else{ 172 int cmp = Double.compare(p.y(), x.p.y()); 173 if(cmp == -1) return contains(x.left,p); 174 else if(cmp == 1 ) return contains(x.right,p); 175 else{ 176 int cmp2 = Double.compare(p.x(), x.p.x()); 177 if(cmp2 == -1) return contains(x.left,p); 178 else if(cmp2 == 1 ) return contains(x.right,p); 179 else return true; 180 } 181 } 182 } 183 } 184 185 public void draw() { 186 // draw all points to standard draw 187 StdDraw.setXscale(0, 1); 188 StdDraw.setYscale(0, 1); 189 draw(root); 190 } 191 192 private void draw(Node x) { 193 if (x == null) 194 return; 195 StdDraw.setPenColor(StdDraw.BLACK); 196 StdDraw.setPenRadius(0.01); 197 x.p.draw(); 198 if (x.xCoordinate == true) { 199 StdDraw.setPenColor(StdDraw.RED); 200 StdDraw.setPenRadius(); 201 Point2D start = new Point2D(x.p.x(), x.rect.ymin()); 202 Point2D end = new Point2D(x.p.x(), x.rect.ymax()); 203 start.drawTo(end); 204 } else { 205 StdDraw.setPenColor(StdDraw.BLUE); 206 StdDraw.setPenRadius(); 207 Point2D start = new Point2D(x.rect.xmin(), x.p.y()); 208 Point2D end = new Point2D(x.rect.xmax(), x.p.y()); 209 start.drawTo(end); 210 } 211 draw(x.left); 212 draw(x.right); 213 } 214 215 public Iterable<Point2D> range(RectHV rect) { 216 // all points that are inside the rectangle (or on the boundary) 217 if (rect == null) 218 throw new IllegalArgumentException("RectHV rect is not illegal!"); 219 if 
(root != null) 220 return range(root, rect); 221 else 222 return new ArrayList<Point2D>(); 223 } 224 225 private ArrayList<Point2D> range(Node x, RectHV rect) { 226 ArrayList<Point2D> points = new ArrayList<Point2D>(); 227 if (x.rect.intersects(rect)) { 228 if (rect.contains(x.p)) 229 points.add(x.p); 230 if (x.left != null) 231 points.addAll(range(x.left, rect)); 232 if (x.right != null) 233 points.addAll(range(x.right, rect)); 234 } 235 return points; 236 } 237 238 public Point2D nearest(Point2D p) { 239 // a nearest neighbor in the set to point p; null if the set is empty 240 if (p == null) 241 throw new IllegalArgumentException("Point2D p is not illegal!"); 242 if (root != null) 243 return nearest(root, p, root.p); 244 return null; 245 } 246 247 /** 248 * 作业提交提示nearest的时间复杂度偏高,导致作业只有98分,我觉得这样写比较清晰明了,就懒得继续优化 249 * @param x 250 * @param p 251 * @param currNearPoint 252 * @return 253 */ 254 private Point2D nearest(Node x, Point2D p, Point2D currNearPoint) { 255 if(x.p.equals(p)) return x.p; 256 double currMinDistance = currNearPoint.distanceTo(p); 257 if (Double.compare(x.rect.distanceTo(p), currMinDistance) >= 0) 258 return currNearPoint; 259 else { 260 double distance = x.p.distanceTo(p); 261 if (Double.compare(x.p.distanceTo(p), currMinDistance) == -1) { 262 currNearPoint = x.p; 263 currMinDistance = distance; 264 } 265 if (x.left != null) 266 currNearPoint = nearest(x.left, p, currNearPoint); 267 if (x.right != null) 268 currNearPoint = nearest(x.right, p, currNearPoint); 269 } 270 return currNearPoint; 271 } 272 273 public static void main(String[] args) { 274 // unit testing of the methods (optional) 275 } 276 }