Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- /*Correctness: 35/35 tests passed
- Memory: 16/16 tests passed
- Timing: 42/42 tests passed
- Aggregate score: 100.00%*/
- import edu.princeton.cs.algs4.In;
- import edu.princeton.cs.algs4.Point2D;
- import edu.princeton.cs.algs4.RectHV;
- import edu.princeton.cs.algs4.SET;
- import edu.princeton.cs.algs4.StdDraw;
- import edu.princeton.cs.algs4.StdOut;
- public class KdTree {
- private Node root;
- private double Xmin;
- private double Xmax;
- private double Ymin;
- private double Ymax;
- private SET<Point2D> points;
- private double[] unitSquare;
- private int count;
- private byte orientation; // default set to zero is vertical
- private byte lineOrientation;
- private Point2D champion;
- private double distChampion;
- // construct an empty Tree of points
- public KdTree() {
- this.Xmin = 0;
- this.Ymin = 0;
- this.Xmax = 1;
- this.Ymax = 1;
- this.unitSquare = new double[]{Xmin, Ymin, Xmax, Ymax};
- this.count = 0;
- }
- // is the tree empty?
- public boolean isEmpty() {
- return size() == 0;
- }
- // number of points in the tree
- public int size() {
- if (null == root) return 0;
- else return size(root);
- }
- private int size(Node x) {
- if (x == null) return 0;
- else return x.size;
- }
- // add the point to the tree (if it is not already in it)
- public void insert(Point2D p) {
- if (null == p) throw new IllegalArgumentException();
- if (!contains(p)) {
- count++;
- double[] xyInsertPoint = new double[]{p.x(), p.y()};
- put(xyInsertPoint);
- }
- }
- private void put(double[] xyInsertPoint) { // add new node to subtree
- root = put(root, unitSquare, xyInsertPoint, orientation);
- }
- // TODO REFACTOR
- // insert point in the tree rooted at node
- private Node put(Node node, double[] rectangle, double[] xyInsertPoint, byte orientation) {
- if (node == null) {
- // reset rectangle values for next point
- Xmin = 0;
- Ymin = 0;
- Xmax = 1;
- Ymax = 1;
- Point2D insertPoint = new Point2D(xyInsertPoint[0], xyInsertPoint[1]);
- return new Node(insertPoint, rectangle, orientation, 1);
- } else {
- if (orientation == 0) { // splitting line is vertical
- orientation = 1; // alternate key
- double xNodePoint = node.point.x();
- if (xyInsertPoint[0] < xNodePoint) { // we go left
- // the maximum x value needs to be updated
- Xmax = Math.min(Xmax, xNodePoint);
- double[] rectArray = new double[]{Xmin, Ymin, Xmax, Ymax};
- node.lb = put(node.lb, rectArray, xyInsertPoint, orientation); // insertPoint is less than node point so we go left
- } else { // we go right and alternate orientation
- //the minimum x value needs to be updated
- Xmin = Math.max(Xmin, xNodePoint);
- double[] rectArray = new double[]{Xmin, Ymin, Xmax, Ymax};
- node.rt = put(node.rt, rectArray, xyInsertPoint, orientation);
- }
- } else { // splitting line is horizontal
- orientation = 0;
- double yNode = node.point.y();
- if (xyInsertPoint[1] < yNode) { // we go bottom and alternate orientation
- // we need to update the ceiling value here so Ymax
- Ymax = Math.min(Ymax, yNode);
- double[] rectArray = new double[]{Xmin, Ymin, Xmax, Ymax};
- node.lb = put(node.lb, rectArray, xyInsertPoint, orientation);
- } else { // we go top and alternate orientation
- // we update the floor value here so Ymin
- Ymin = Math.max(Ymin, yNode);
- double[] rectArray = new double[]{Xmin, Ymin, Xmax, Ymax};
- node.rt = put(node.rt, rectArray, xyInsertPoint, orientation);
- }
- }
- node.size = 1 + size(node.lb) + size(node.rt);
- return node;
- }
- }
- //does the tree contains point p?
- public boolean contains(Point2D queryPoint) {
- if (null == queryPoint) throw new IllegalArgumentException();
- if (root != null) {
- Point2D resultPoint;
- double[] searchPoint = new double[]{queryPoint.x(), queryPoint.y()};
- resultPoint = get(root, searchPoint, orientation); // root is already in so we have something to compare
- if (resultPoint != null) {
- return queryPoint.equals(resultPoint);
- } else return false;
- } else return false;
- }
- // TODO REFACTOR
- // find point in the tree rooted at node
- private Point2D get(Node node, double[] searchedPoint, byte orientation) {
- if (null == node) {
- return null;
- } else {
- double compareX = searchedPoint[0] - node.point.x();
- double compareY = searchedPoint[1] - node.point.y();
- if (orientation == 0) { // vertical axis split
- orientation = 1;
- if (compareX == 0 && compareY == 0) return node.point;
- else if (compareX < 0) {
- return get(node.lb, searchedPoint, orientation);
- } else {
- return get(node.rt, searchedPoint, orientation);
- }
- } else {
- orientation = 0;
- if (compareY == 0 && compareX == 0) return node.point;
- else if (compareY < 0) { // we go left-bottom and alternate orientation
- return get(node.lb, searchedPoint, orientation);
- } else {
- return get(node.rt, searchedPoint, orientation);
- }
- }
- }
- }
- public void draw() {
- if (root == null) throw new IllegalArgumentException();
- Node node = root;
- drawPoint(node);
- }
- private void drawLine(Node node) {
- if (node == null) throw new IllegalArgumentException();
- lineOrientation = node.orientation;
- if (lineOrientation == 0) { // vertical
- StdDraw.setPenColor(StdDraw.RED); // vertical
- StdDraw.setPenRadius();
- StdDraw.line(node.point.x(), node.rectArray[1], node.point.x(), node.rectArray[3]);
- } else if (lineOrientation == 1) { // horizontal
- StdDraw.setPenColor(StdDraw.BLUE);
- StdDraw.setPenRadius();
- StdDraw.line(node.rectArray[0], node.point.y(), node.rectArray[2], node.point.y());
- }
- }
- private void drawPoint(Node node) {
- StdDraw.setPenColor(StdDraw.BLACK);
- StdDraw.setPenRadius(0.01);
- node.point.draw();
- drawLine(node);
- if (node.rt != null) {
- drawPoint(node.rt);
- }
- if (node.lb != null) {
- drawPoint(node.lb);
- } else return;
- }
- //All points that are inside the rectangle (or on the boundary)
- public Iterable<Point2D> range(RectHV rect) {
- if (rect == null) throw new IllegalArgumentException();
- points = new SET<>();
- if (null != root) {
- points = range(root, rect, 0);
- }
- return points;
- }
- private SET<Point2D> range(Node node, RectHV searchRectangle, int orientation) {
- if (searchRectangle.contains(node.point)) {
- points.add(node.point);
- }
- if (orientation == 0) {// does searchRectangle intersects the VERTICAL splitting line segment?
- orientation = 1; // alternate
- double xmin = node.point.x(); // build the vertical splitting line segment
- double ymin = node.rectArray[1];
- double xmax = node.point.x();
- double ymax = node.rectArray[3];
- RectHV splittingLine = new RectHV(xmin, ymin, xmax, ymax); // splitting vertical line of node.point
- if (searchRectangle.intersects(splittingLine)) { // we have to search both subtrees
- if (node.lb != null) {
- range(node.lb, searchRectangle, orientation);
- }
- if (node.rt != null) {
- range(node.rt, searchRectangle, orientation);
- }
- } else { // if not we search only the one subtree where searchRectangle is located
- double dist = node.point.x() - searchRectangle.xmax(); // could be any searchRectangle point
- if (dist > 0) { // go left
- if (node.lb != null) {
- range(node.lb, searchRectangle, orientation);
- }
- }
- if (dist < 0) { // go right
- if (node.rt != null) {
- range(node.rt, searchRectangle, orientation);
- }
- }
- }
- } else { // orientation is 1 = horizontal
- orientation = 0;
- double xmin = node.rectArray[0];
- double ymin = node.point.y();
- double xmax = node.rectArray[2];
- double ymax = node.point.y();
- RectHV splittingLine = new RectHV(xmin, ymin, xmax, ymax); // splitting vertical line of node.point
- if (searchRectangle.intersects(splittingLine)) {
- if (node.lb != null) {
- range(node.lb, searchRectangle, orientation);
- }
- if (node.rt != null) {
- range(node.rt, searchRectangle, orientation);
- }
- } else {
- double dist = node.point.y() - searchRectangle.ymax();
- if (dist < 0) { // go up top
- if (node.rt != null) {
- range(node.rt, searchRectangle, orientation);
- }
- }
- if (dist > 0) {
- if (node.lb != null) {
- range(node.lb, searchRectangle, orientation);
- }
- }
- }
- }
- return points;
- }
- public Point2D nearest(Point2D queryPoint) {
- if (queryPoint == null) throw new IllegalArgumentException();
- Point2D nearestPoint = null;
- if (null != root) {
- nearestPoint = nearest(root, root.point, queryPoint);
- }
- return nearestPoint;
- }
- private Point2D nearest(Node searchNode, Point2D closerPoint, Point2D queryPoint) {
- if (root == null) throw new IllegalArgumentException();
- Point2D nodePoint = searchNode.point;
- double distNodePoint = nodePoint.distanceSquaredTo(queryPoint);
- double distCloserPoint = closerPoint.distanceSquaredTo(queryPoint);
- champion = distCloserPoint < distNodePoint ? closerPoint : nodePoint;
- distChampion = Math.min(distCloserPoint, distNodePoint); //champion.distanceSquaredTo(queryPoint);
- double distRt = -1;
- double distLb = -1;
- if (searchNode.lb != null) {
- distLb = getDist(queryPoint, searchNode.lb);
- }
- if (searchNode.rt != null) { // 2 subtrees to search
- distRt = getDist(queryPoint, searchNode.rt);
- }
- if (searchNode.lb != null && searchNode.rt != null) { // 2 subtrees to search
- if (distLb < distRt) { // we start searching left tree as it is closer to queryPoint
- nearest(searchNode.lb, champion, queryPoint);
- if (distRt < distChampion) { // PRUNING no need to search in rt tree
- nearest(searchNode.rt, champion, queryPoint);
- }
- } else { // start searching right subtree. QueryPoint might be on the splitting line
- nearest(searchNode.rt, champion, queryPoint);
- if (distLb < distChampion) {
- nearest(searchNode.lb, champion, queryPoint);
- }
- }
- } else if (searchNode.lb != null) { // only 1 subtree to search
- if (distLb < distChampion) { // search lb node
- nearest(searchNode.lb, champion, queryPoint);
- }
- } else if (searchNode.rt != null) { // only 1 to search
- if (distRt < distChampion) { // search rt node
- nearest(searchNode.rt, champion, queryPoint);
- }
- }
- return champion;
- }
- private double getDist(Point2D queryPoint, Node node) {
- if (null == node) throw new IllegalArgumentException();
- double xMin;
- double yMin;
- double xMax;
- double yMax;
- RectHV rect;
- double dist;
- xMin = node.rectArray[0];
- yMin = node.rectArray[1];
- xMax = node.rectArray[2];
- yMax = node.rectArray[3];
- rect = new RectHV(xMin, yMin, xMax, yMax);
- dist = rect.distanceSquaredTo(queryPoint);
- return dist;
- }
- private static class Node {
- private Point2D point; // the point as a key to subtree
- private double[] rectArray; // the axis-aligned rectangle corresponding to this node as value to this key
- private Node lb; // the left/bottom subtree as links
- private Node rt; // the right/top subtree as links
- private byte orientation;
- private int size; // number of nodes in subtree rooted here
- public Node(Point2D point, double[] rectArray, byte orientation, int n) {
- this.point = point;
- this.rectArray = rectArray;
- this.orientation = orientation;
- this.size = n;
- }
- }
- public static void main(String[] args) {
- String filename = args[0];
- In in = new In(filename);
- KdTree kdtree = new KdTree();
- /* StdOut.println("before building tree size is :" + kdtree.size());
- StdOut.println("before building tree isEmpty() is :" + kdtree.isEmpty());*/
- while (!in.isEmpty()) {
- double x = in.readDouble();
- double y = in.readDouble();
- Point2D p = new Point2D(x, y);
- /* StdOut.println("before building tree size is :" + kdtree.size());
- StdOut.println("before building tree isEmpty() is :" + kdtree.isEmpty());*/
- kdtree.insert(p);
- }
- /*
- Point2D p0 = new Point2D(0.7, 0.2);
- Point2D p1 = new Point2D(0.5, 0.4);
- Point2D p2 = new Point2D(0.2, 0.3);
- Point2D p3 = new Point2D(0.4, 0.7);
- Point2D p4 = new Point2D(0.9, 0.6);
- Point2D p5 = new Point2D(0.1, 0.2);
- Point2D p6 = new Point2D(0.1, 0.3);
- Point2D p7 = new Point2D(0.0, 0.3);
- Point2D p8 = new Point2D(0.0, 0.1);
- Point2D p9 = new Point2D(0.5, 0.1);
- KdTree tree = new KdTree();
- tree.insert(p0);
- tree.insert(p1);
- tree.insert(p2);
- tree.insert(p3);
- tree.insert(p4);
- tree.insert(p5);
- tree.insert(p6);
- tree.insert(p7);
- tree.insert(p8);
- tree.insert(p9);
- */
- /*
- RectHV rec = new RectHV(0.11, 0.32, 0.16, 0.46);
- StdOut.println("size of tree is :" + kdtree.size());
- SET<Point2D> res = (SET<Point2D>) kdtree.range(rec);
- res.forEach(p -> StdOut.println(" range : " + p.toString()));*/
- kdtree.draw();
- StdDraw.setPenRadius(0.02);
- // draw in blue the nearest neighbor (using kd-tree algorithm)
- Point2D q1 = new Point2D(0.93, 0.5); // 0.378, 0.075); //0.91, 0.41); // 0.417, 0.715); // (0.037, 0.416);
- StdDraw.setPenColor(StdDraw.GREEN);
- q1.draw();
- StdDraw.setPenColor(StdDraw.BLUE);
- Point2D closest = kdtree.nearest(q1);
- closest.draw();
- StdOut.println("closest point is: " + closest.toString());
- /*RectHV rect = new RectHV(0.03, 0.28, 0.59, 0.66); //0.5, 0.73 , 0.54, 0.99);
- StdDraw.setPenColor(StdDraw.GREEN);
- rect.draw();
- StdOut.println("range " + kdtree.range(rect));*/
- }
- }
Add Comment
Please, Sign In to add comment