jules0707

KdTree

Nov 9th, 2020 (edited)
188
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 16.09 KB | None | 0 0
/* Autograder results for this submission:
 *   Correctness:      35/35 tests passed
 *   Memory:           16/16 tests passed
 *   Timing:           42/42 tests passed
 *   Aggregate score:  100.00%
 */
  5.  
  6.  
  7. import edu.princeton.cs.algs4.In;
  8. import edu.princeton.cs.algs4.Point2D;
  9. import edu.princeton.cs.algs4.RectHV;
  10. import edu.princeton.cs.algs4.SET;
  11. import edu.princeton.cs.algs4.StdDraw;
  12. import edu.princeton.cs.algs4.StdOut;
  13.  
  14. public class KdTree {
  15.  
  16.     private Node root;
  17.     private double Xmin;
  18.     private double Xmax;
  19.     private double Ymin;
  20.     private double Ymax;
  21.     private SET<Point2D> points;
  22.     private double[] unitSquare;
  23.     private int count;
  24.     private byte orientation; // default set to zero is vertical
  25.     private byte lineOrientation;
  26.     private Point2D champion;
  27.     private double distChampion;
  28.  
  29.     // construct an empty Tree of points
  30.     public KdTree() {
  31.         this.Xmin = 0;
  32.         this.Ymin = 0;
  33.         this.Xmax = 1;
  34.         this.Ymax = 1;
  35.         this.unitSquare = new double[]{Xmin, Ymin, Xmax, Ymax};
  36.         this.count = 0;
  37.     }
  38.  
  39.     // is the tree empty?
  40.     public boolean isEmpty() {
  41.         return size() == 0;
  42.     }
  43.  
  44.     // number of points in the tree
  45.     public int size() {
  46.         if (null == root) return 0;
  47.         else return size(root);
  48.     }
  49.  
  50.     private int size(Node x) {
  51.         if (x == null) return 0;
  52.         else return x.size;
  53.     }
  54.  
  55.     // add the point to the tree (if it is not already in it)
  56.     public void insert(Point2D p) {
  57.         if (null == p) throw new IllegalArgumentException();
  58.         if (!contains(p)) {
  59.             count++;
  60.             double[] xyInsertPoint = new double[]{p.x(), p.y()};
  61.             put(xyInsertPoint);
  62.         }
  63.     }
  64.  
  65.     private void put(double[] xyInsertPoint) { // add new node to subtree
  66.         root = put(root, unitSquare, xyInsertPoint, orientation);
  67.     }
  68.  
  69.     // TODO REFACTOR
  70.     // insert point in the tree rooted at node
  71.     private Node put(Node node, double[] rectangle, double[] xyInsertPoint, byte orientation) {
  72.         if (node == null) {
  73.             // reset rectangle values for next point
  74.             Xmin = 0;
  75.             Ymin = 0;
  76.             Xmax = 1;
  77.             Ymax = 1;
  78.             Point2D insertPoint = new Point2D(xyInsertPoint[0], xyInsertPoint[1]);
  79.             return new Node(insertPoint, rectangle, orientation, 1);
  80.         } else {
  81.             if (orientation == 0) { // splitting line is vertical
  82.                 orientation = 1; // alternate key
  83.                 double xNodePoint = node.point.x();
  84.                 if (xyInsertPoint[0] < xNodePoint) { // we go left
  85.                     // the maximum x value needs to be updated
  86.                     Xmax = Math.min(Xmax, xNodePoint);
  87.                     double[] rectArray = new double[]{Xmin, Ymin, Xmax, Ymax};
  88.                     node.lb = put(node.lb, rectArray, xyInsertPoint, orientation);  // insertPoint is less than node point so we go left
  89.                 } else { // we go right and alternate orientation
  90.                     //the minimum x value needs to be updated
  91.                     Xmin = Math.max(Xmin, xNodePoint);
  92.                     double[] rectArray = new double[]{Xmin, Ymin, Xmax, Ymax};
  93.                     node.rt = put(node.rt, rectArray, xyInsertPoint, orientation);
  94.                 }
  95.             } else { // splitting line is horizontal
  96.                 orientation = 0;
  97.                 double yNode = node.point.y();
  98.                 if (xyInsertPoint[1] < yNode) { // we go bottom and alternate orientation
  99.                     // we need to update the ceiling value here so Ymax
  100.                     Ymax = Math.min(Ymax, yNode);
  101.                     double[] rectArray = new double[]{Xmin, Ymin, Xmax, Ymax};
  102.                     node.lb = put(node.lb, rectArray, xyInsertPoint, orientation);
  103.                 } else { // we go top and alternate orientation
  104.                     // we update the floor value here so Ymin
  105.                     Ymin = Math.max(Ymin, yNode);
  106.                     double[] rectArray = new double[]{Xmin, Ymin, Xmax, Ymax};
  107.                     node.rt = put(node.rt, rectArray, xyInsertPoint, orientation);
  108.                 }
  109.             }
  110.             node.size = 1 + size(node.lb) + size(node.rt);
  111.             return node;
  112.         }
  113.     }
  114.  
  115.     //does the tree contains point p?
  116.     public boolean contains(Point2D queryPoint) {
  117.         if (null == queryPoint) throw new IllegalArgumentException();
  118.         if (root != null) {
  119.             Point2D resultPoint;
  120.             double[] searchPoint = new double[]{queryPoint.x(), queryPoint.y()};
  121.             resultPoint = get(root, searchPoint, orientation); // root is already in so we have something to compare
  122.             if (resultPoint != null) {
  123.                 return queryPoint.equals(resultPoint);
  124.             } else return false;
  125.         } else return false;
  126.     }
  127.  
  128.     //  TODO  REFACTOR
  129.     // find point in the tree rooted at node
  130.     private Point2D get(Node node, double[] searchedPoint, byte orientation) {
  131.         if (null == node) {
  132.             return null;
  133.         } else {
  134.             double compareX = searchedPoint[0] - node.point.x();
  135.             double compareY = searchedPoint[1] - node.point.y();
  136.  
  137.             if (orientation == 0) { // vertical axis split
  138.                 orientation = 1;
  139.                 if (compareX == 0 && compareY == 0) return node.point;
  140.                 else if (compareX < 0) {
  141.                     return get(node.lb, searchedPoint, orientation);
  142.                 } else {
  143.                     return get(node.rt, searchedPoint, orientation);
  144.                 }
  145.             } else {
  146.                 orientation = 0;
  147.                 if (compareY == 0 && compareX == 0) return node.point;
  148.                 else if (compareY < 0) { // we go left-bottom and alternate orientation
  149.                     return get(node.lb, searchedPoint, orientation);
  150.                 } else {
  151.                     return get(node.rt, searchedPoint, orientation);
  152.                 }
  153.             }
  154.         }
  155.     }
  156.  
  157.     public void draw() {
  158.         if (root == null) throw new IllegalArgumentException();
  159.         Node node = root;
  160.         drawPoint(node);
  161.     }
  162.  
  163.     private void drawLine(Node node) {
  164.         if (node == null) throw new IllegalArgumentException();
  165.         lineOrientation = node.orientation;
  166.         if (lineOrientation == 0) { // vertical
  167.             StdDraw.setPenColor(StdDraw.RED); // vertical
  168.             StdDraw.setPenRadius();
  169.             StdDraw.line(node.point.x(), node.rectArray[1], node.point.x(), node.rectArray[3]);
  170.         } else if (lineOrientation == 1) { // horizontal
  171.             StdDraw.setPenColor(StdDraw.BLUE);
  172.             StdDraw.setPenRadius();
  173.             StdDraw.line(node.rectArray[0], node.point.y(), node.rectArray[2], node.point.y());
  174.         }
  175.     }
  176.  
  177.     private void drawPoint(Node node) {
  178.         StdDraw.setPenColor(StdDraw.BLACK);
  179.         StdDraw.setPenRadius(0.01);
  180.         node.point.draw();
  181.         drawLine(node);
  182.  
  183.         if (node.rt != null) {
  184.             drawPoint(node.rt);
  185.         }
  186.         if (node.lb != null) {
  187.             drawPoint(node.lb);
  188.         } else return;
  189.     }
  190.  
  191.     //All points that are inside the rectangle (or on the boundary)
  192.     public Iterable<Point2D> range(RectHV rect) {
  193.         if (rect == null) throw new IllegalArgumentException();
  194.         points = new SET<>();
  195.         if (null != root) {
  196.             points = range(root, rect, 0);
  197.         }
  198.         return points;
  199.     }
  200.  
  201.     private SET<Point2D> range(Node node, RectHV searchRectangle, int orientation) {
  202.         if (searchRectangle.contains(node.point)) {
  203.             points.add(node.point);
  204.         }
  205.         if (orientation == 0) {// does searchRectangle intersects the VERTICAL splitting line segment?
  206.             orientation = 1; // alternate
  207.             double xmin = node.point.x(); // build the vertical splitting line segment
  208.             double ymin = node.rectArray[1];
  209.             double xmax = node.point.x();
  210.             double ymax = node.rectArray[3];
  211.             RectHV splittingLine = new RectHV(xmin, ymin, xmax, ymax); // splitting vertical line of node.point
  212.             if (searchRectangle.intersects(splittingLine)) { // we have to search both subtrees
  213.                 if (node.lb != null) {
  214.                     range(node.lb, searchRectangle, orientation);
  215.                 }
  216.                 if (node.rt != null) {
  217.                     range(node.rt, searchRectangle, orientation);
  218.                 }
  219.             } else { // if not we search only the one subtree where searchRectangle is located
  220.                 double dist = node.point.x() - searchRectangle.xmax(); //  could be any searchRectangle point
  221.                 if (dist > 0) { // go left
  222.                     if (node.lb != null) {
  223.                         range(node.lb, searchRectangle, orientation);
  224.                     }
  225.                 }
  226.                 if (dist < 0) { // go right
  227.                     if (node.rt != null) {
  228.                         range(node.rt, searchRectangle, orientation);
  229.                     }
  230.                 }
  231.             }
  232.         } else { // orientation is 1 = horizontal
  233.             orientation = 0;
  234.             double xmin = node.rectArray[0];
  235.             double ymin = node.point.y();
  236.             double xmax = node.rectArray[2];
  237.             double ymax = node.point.y();
  238.             RectHV splittingLine = new RectHV(xmin, ymin, xmax, ymax); // splitting vertical line of node.point
  239.             if (searchRectangle.intersects(splittingLine)) {
  240.                 if (node.lb != null) {
  241.                     range(node.lb, searchRectangle, orientation);
  242.                 }
  243.                 if (node.rt != null) {
  244.                     range(node.rt, searchRectangle, orientation);
  245.                 }
  246.             } else {
  247.                 double dist = node.point.y() - searchRectangle.ymax();
  248.                 if (dist < 0) { // go up top
  249.                     if (node.rt != null) {
  250.                         range(node.rt, searchRectangle, orientation);
  251.                     }
  252.                 }
  253.                 if (dist > 0) {
  254.                     if (node.lb != null) {
  255.                         range(node.lb, searchRectangle, orientation);
  256.                     }
  257.                 }
  258.             }
  259.         }
  260.         return points;
  261.     }
  262.  
  263.     public Point2D nearest(Point2D queryPoint) {
  264.         if (queryPoint == null) throw new IllegalArgumentException();
  265.         Point2D nearestPoint = null;
  266.         if (null != root) {
  267.             nearestPoint = nearest(root, root.point, queryPoint);
  268.         }
  269.         return nearestPoint;
  270.     }
  271.  
  272.     private Point2D nearest(Node searchNode, Point2D closerPoint, Point2D queryPoint) {
  273.         if (root == null) throw new IllegalArgumentException();
  274.         Point2D nodePoint = searchNode.point;
  275.         double distNodePoint = nodePoint.distanceSquaredTo(queryPoint);
  276.         double distCloserPoint = closerPoint.distanceSquaredTo(queryPoint);
  277.  
  278.         champion = distCloserPoint < distNodePoint ? closerPoint : nodePoint;
  279.         distChampion = Math.min(distCloserPoint, distNodePoint); //champion.distanceSquaredTo(queryPoint);
  280.  
  281.         double distRt = -1;
  282.         double distLb = -1;
  283.  
  284.         if (searchNode.lb != null) {
  285.             distLb = getDist(queryPoint, searchNode.lb);
  286.         }
  287.  
  288.         if (searchNode.rt != null) { // 2 subtrees to search
  289.             distRt = getDist(queryPoint, searchNode.rt);
  290.         }
  291.  
  292.         if (searchNode.lb != null && searchNode.rt != null) { // 2 subtrees to search
  293.             if (distLb < distRt) { // we start searching left tree as it is closer to queryPoint
  294.                 nearest(searchNode.lb, champion, queryPoint);
  295.                 if (distRt < distChampion) { // PRUNING no need to search in rt tree
  296.                     nearest(searchNode.rt, champion, queryPoint);
  297.                 }
  298.             } else { // start searching right subtree. QueryPoint might be on the splitting line
  299.                 nearest(searchNode.rt, champion, queryPoint);
  300.                 if (distLb < distChampion) {
  301.                     nearest(searchNode.lb, champion, queryPoint);
  302.                 }
  303.             }
  304.         } else if (searchNode.lb != null) { // only 1 subtree to search
  305.             if (distLb < distChampion) { // search lb node
  306.                 nearest(searchNode.lb, champion, queryPoint);
  307.             }
  308.         } else if (searchNode.rt != null) { // only 1 to search
  309.             if (distRt < distChampion) { // search rt node
  310.                 nearest(searchNode.rt, champion, queryPoint);
  311.             }
  312.         }
  313.         return champion;
  314.     }
  315.  
  316.  
  317.     private double getDist(Point2D queryPoint, Node node) {
  318.         if (null == node) throw new IllegalArgumentException();
  319.         double xMin;
  320.         double yMin;
  321.         double xMax;
  322.         double yMax;
  323.         RectHV rect;
  324.         double dist;
  325.         xMin = node.rectArray[0];
  326.         yMin = node.rectArray[1];
  327.         xMax = node.rectArray[2];
  328.         yMax = node.rectArray[3];
  329.         rect = new RectHV(xMin, yMin, xMax, yMax);
  330.         dist = rect.distanceSquaredTo(queryPoint);
  331.         return dist;
  332.     }
  333.  
  334.     private static class Node {
  335.         private Point2D point; // the point as a key to subtree
  336.         private double[] rectArray; // the axis-aligned rectangle corresponding to this node as value to this key
  337.         private Node lb; // the left/bottom subtree as links
  338.         private Node rt; // the right/top subtree as links
  339.         private byte orientation;
  340.         private int size; // number of nodes in subtree rooted here
  341.  
  342.         public Node(Point2D point, double[] rectArray, byte orientation, int n) {
  343.             this.point = point;
  344.             this.rectArray = rectArray;
  345.             this.orientation = orientation;
  346.             this.size = n;
  347.         }
  348.     }
  349.  
  350.     public static void main(String[] args) {
  351.         String filename = args[0];
  352.         In in = new In(filename);
  353.         KdTree kdtree = new KdTree();
  354.        /* StdOut.println("before building tree size is :" + kdtree.size());
  355.         StdOut.println("before building tree isEmpty() is  :" + kdtree.isEmpty());*/
  356.  
  357.         while (!in.isEmpty()) {
  358.             double x = in.readDouble();
  359.             double y = in.readDouble();
  360.             Point2D p = new Point2D(x, y);
  361.            /* StdOut.println("before building tree size is :" + kdtree.size());
  362.             StdOut.println("before building tree isEmpty() is  :" + kdtree.isEmpty());*/
  363.             kdtree.insert(p);
  364.         }
  365.  
  366. /*
  367.         Point2D p0 = new Point2D(0.7, 0.2);
  368.         Point2D p1 = new Point2D(0.5, 0.4);
  369.         Point2D p2 = new Point2D(0.2, 0.3);
  370.         Point2D p3 = new Point2D(0.4, 0.7);
  371.         Point2D p4 = new Point2D(0.9, 0.6);
  372.         Point2D p5 = new Point2D(0.1, 0.2);
  373.         Point2D p6 = new Point2D(0.1, 0.3);
  374.         Point2D p7 = new Point2D(0.0, 0.3);
  375.         Point2D p8 = new Point2D(0.0, 0.1);
  376.         Point2D p9 = new Point2D(0.5, 0.1);
  377.  
  378.         KdTree tree = new KdTree();
  379.  
  380.         tree.insert(p0);
  381.         tree.insert(p1);
  382.         tree.insert(p2);
  383.         tree.insert(p3);
  384.         tree.insert(p4);
  385.         tree.insert(p5);
  386.         tree.insert(p6);
  387.         tree.insert(p7);
  388.         tree.insert(p8);
  389.         tree.insert(p9);
  390. */
  391.  
  392. /*
  393.         RectHV rec = new RectHV(0.11, 0.32, 0.16, 0.46);
  394.         StdOut.println("size of tree is :" + kdtree.size());
  395.         SET<Point2D> res = (SET<Point2D>) kdtree.range(rec);
  396.         res.forEach(p -> StdOut.println(" range : " + p.toString()));*/
  397.  
  398.         kdtree.draw();
  399.         StdDraw.setPenRadius(0.02);
  400.         // draw in blue the nearest neighbor (using kd-tree algorithm)
  401.  
  402.         Point2D q1 = new Point2D(0.93, 0.5); // 0.378, 0.075); //0.91, 0.41); // 0.417, 0.715); // (0.037, 0.416);
  403.         StdDraw.setPenColor(StdDraw.GREEN);
  404.         q1.draw();
  405.         StdDraw.setPenColor(StdDraw.BLUE);
  406.         Point2D closest = kdtree.nearest(q1);
  407.         closest.draw();
  408.         StdOut.println("closest point is: " + closest.toString());
  409.  
  410.         /*RectHV rect = new RectHV(0.03, 0.28, 0.59, 0.66);  //0.5, 0.73 , 0.54, 0.99);
  411.         StdDraw.setPenColor(StdDraw.GREEN);
  412.         rect.draw();
  413.         StdOut.println("range " + kdtree.range(rect));*/
  414.     }
  415. }
  416.  
Add Comment
Please, Sign In to add comment