Advertisement
Screbber

Untitled

Dec 13th, 2022
247
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.62 KB | None | 0 0
  1. // проверять именно это решение тут источники:
  2. // https://www.geeksforgeeks.org/find-closest-element-binary-search-tree/
  3. // https://habr.com/ru/post/150732/
  4.  
  5. #include <iostream>
  6.  
  7. struct Node {
  8.     int value;
  9.     Node *left;
  10.     Node *right;
  11.     int height;
  12.  
  13.     explicit Node(int value) : value(value) {
  14.         left = right = nullptr;
  15.         height = 1;
  16.     }
  17. };
  18.  
  19. class AVLTree {
  20. public:
  21.     AVLTree() {
  22.         root_ = nullptr;
  23.     }
  24.  
  25.     int getHeight() {
  26.         return getHeight(root_);
  27.     }
  28.  
  29.     void insert(int value) {
  30.         root_ = insertInto(root_, value);
  31.     }
  32.  
  33.     void erase(int value) {
  34.         root_ = remove(root_, value);
  35.     }
  36.  
  37.     int *find(int value) {
  38.         return find(root_, value);
  39.     }
  40.  
  41.     int *traversal() {
  42.         size_ = 0;
  43.         size_t size = getSize(root_);
  44.         data_ = new int[size];
  45.         traversal(root_);
  46.         return data_;
  47.     }
  48.  
  49.     int *lowerBound(int value) {
  50.         int min_diff = 10000000, min_diff_key = -1;
  51.         lowerBound(root_, value, min_diff, min_diff_key);
  52.         return find(min_diff_key);
  53.     }
  54.  
  55.     bool empty() {
  56.         return root_ == nullptr;
  57.     }
  58.  
  59.     Node *getRoot() {
  60.         return root_;
  61.     }
  62.  
  63.     int getSize() {
  64.         return getSize(root_);
  65.     }
  66.  
  67.     void print() {
  68.         print(root_);
  69.     }
  70.  
  71.     ~AVLTree() {
  72.         clear(root_);
  73.     }
  74.  
  75. private:
  76.     Node *root_ = nullptr;
  77.     int *data_;
  78.     size_t size_;
  79.  
  80.     int *find(Node *node, int value) {
  81.         if (node != nullptr) {
  82.             if (value < node->value) {
  83.                 return find(node->left, value);
  84.             }
  85.             if (value > node->value) {
  86.                 return find(node->right, value);
  87.             }
  88.             if (node->value == value) {
  89.                 return &node->value;
  90.             }
  91.         }
  92.         return nullptr;
  93.     }
  94.  
  95.     void traversal(Node *node) {
  96.         if (node->left != nullptr) {
  97.             traversal(node->left);
  98.         }
  99.         data_[size_++] = node->value;
  100.         if (node->right != nullptr) {
  101.             traversal(node->right);
  102.         }
  103.     }
  104.  
  105.     void lowerBound(struct Node *node, int lower_value, int &min_diff, int &min_diff_value) {
  106.         if (node == nullptr) {
  107.             return;
  108.         }
  109.         if (node->value == lower_value) {
  110.             min_diff_value = lower_value;
  111.             return;
  112.         }
  113.         if (lower_value <= node->value && min_diff > node->value - lower_value) {
  114.             min_diff = node->value - lower_value;
  115.             min_diff_value = node->value;
  116.         }
  117.         if (lower_value < node->value) {
  118.             lowerBound(node->left, lower_value, min_diff, min_diff_value);
  119.         } else {
  120.             lowerBound(node->right, lower_value, min_diff, min_diff_value);
  121.         }
  122.     }
  123.  
  124.     Node *findMin(Node *node) {
  125.         return node->left ? findMin(node->left) : node;
  126.     }
  127.  
  128.     // удаление узла с минимальным ключом из дерева r
  129.     Node *removeMin(Node *r) {
  130.         if (r->left == nullptr) {
  131.             return r->right;
  132.         }
  133.         r->left = removeMin(r->left);
  134.         return balance(r);
  135.     }
  136.  
  137.     Node *remove(Node *p, int value) {
  138.         if (p == nullptr) {
  139.             return nullptr;
  140.         }
  141.         if (value < p->value) {
  142.             p->left = remove(p->left, value);
  143.         }
  144.         if (value > p->value) {
  145.             p->right = remove(p->right, value);
  146.         }
  147.         if (value == p->value) {
  148.             Node *q = p->left;
  149.             Node *r = p->right;
  150.             delete p;
  151.             if (r == nullptr) {
  152.                 return q;
  153.             }
  154.             Node *min = findMin(r);
  155.             min->right = removeMin(r);
  156.             min->left = q;
  157.             return balance(min);
  158.         }
  159.         return balance(p);
  160.     }
  161.  
  162.     int getHeight(Node *node) {
  163.         int left = 0, right = 0, count = 0;
  164.         if (node != nullptr) {
  165.             left = getHeight(node->left);
  166.             right = getHeight(node->right);
  167.             count = ((left > right) ? left : right) + 1;
  168.         }
  169.         return count;
  170.     }
  171.  
  172.     int getSize(Node *node) {
  173.         int left = 0, right = 0, sum = 0;
  174.         if (node != nullptr) {
  175.             left = getSize(node->left);
  176.             right = getSize(node->right);
  177.             sum = right + left + 1;
  178.         }
  179.         return sum;
  180.     }
  181.  
  182.     int height(Node *node) {
  183.         return node ? node->height : 0;
  184.     }
  185.  
  186.     int balanceFactor(Node *node) {
  187.         return height(node->right) - height(node->left);
  188.     }
  189.  
  190.     void fixHeight(Node *node) {
  191.         int hl = height(node->left);
  192.         int hr = height(node->right);
  193.         node->height = (hl > hr ? hl : hr) + 1;
  194.     }
  195.  
  196.     // Правый поворот относительно p
  197.     Node *rightRotate(Node *p) {
  198.         Node *q = p->left;
  199.         p->left = q->right;
  200.         q->right = p;
  201.         fixHeight(p);
  202.         fixHeight(q);
  203.         return q;
  204.     }
  205.  
  206.     Node *leftRotate(Node *q) {
  207.         Node *p = q->right;
  208.         q->right = p->left;
  209.         p->left = q;
  210.         fixHeight(q);
  211.         fixHeight(p);
  212.         return p;
  213.     }
  214.  
  215.     // балансировка узла p
  216.     Node *balance(Node *p) {
  217.         fixHeight(p);
  218.         // Перекос вправо
  219.         if (balanceFactor(p) == 2) {
  220.             if (balanceFactor(p->right) < 0) {
  221.                 p->right = rightRotate(p->right);
  222.             }
  223.             return leftRotate(p);
  224.         }
  225.         // Перекос влево
  226.         if (balanceFactor(p) == -2) {
  227.             if (balanceFactor(p->left) > 0) {
  228.                 p->left = leftRotate(p->left);
  229.             }
  230.             return rightRotate(p);
  231.         }
  232.         return p;
  233.     }
  234.  
  235.     Node *insertInto(Node *node, int value) {
  236.         if (node == nullptr) {
  237.             return new Node(value);
  238.         }
  239.         if (value < node->value) {
  240.             node->left = insertInto(node->left, value);
  241.         }
  242.         if (value > node->value) {
  243.             node->right = insertInto(node->right, value);
  244.         }
  245.         return balance(node);
  246.     }
  247.  
  248.     void clear(Node *node) {
  249.         if (node != nullptr) {
  250.             if (node->right != nullptr) {
  251.                 clear(node->right);
  252.             }
  253.             if (node->left != nullptr) {
  254.                 clear(node->left);
  255.             }
  256.             delete node;
  257.         }
  258.     }
  259.  
  260.     void print(Node *node) {
  261.         if (node == nullptr) {
  262.             return;
  263.         }
  264.         print(node->left);
  265.         std::cout << node->value << ' ';
  266.         print(node->right);
  267.     }
  268. };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement