diff --git a/src/lol/algorithm/avl_tree.h b/src/lol/algorithm/avl_tree.h index 85a56c4f..88fd1b6e 100644 --- a/src/lol/algorithm/avl_tree.h +++ b/src/lol/algorithm/avl_tree.h @@ -30,20 +30,26 @@ public: { if (!m_root) this->m_root = new tree_node(key, value); + else if (this->m_root->exists(key)) + { + this->m_root->insert_or_update(key, value); + return false; + } else { - tree_node * created = this->m_root->insert(key, value); + this->m_root->insert_or_update(key, value); + this->m_root->update_balance(key); + tree_node * node = this->m_root->get_unbalanced_parent(key); - if (created) + if (node) { - this->m_root->path_update_balance(key); - - tree_node * new_root = this->m_root->path_rebalance(key); - if (new_root) - this->m_root = new_root; + node->rebalance_children(key); + this->m_root->update_balance(node); } - else - return false; + else if (this->m_root->get_balance() == -2) + this->m_root = this->m_root->rotate(tree_node::CW); + else if (this->m_root->get_balance() == 2) + this->m_root = this->m_root->rotate(tree_node::CCW); } return true; @@ -62,72 +68,155 @@ protected: m_stairs[0] = m_stairs[1] = 0; } - tree_node * insert(K const & key, V const & value) + bool exists(K key) { - tree_node * ret = 0; + if (key < this->m_key) + { + if (this->m_child[0]) + return this->m_child[0]->exists(key); + else + return false; + } + if (this->m_key < key) + { + if (this->m_child[1]) + return this->m_child[1]->exists(key); + else + return false; + } + + return true; + } + void insert_or_update(K const & key, V const & value) + { if (key < this->m_key) { if (this->m_child[0]) - ret = this->m_child[0]->insert(key, value); + this->m_child[0]->insert_or_update(key, value); else - ret = this->m_child[0] = new tree_node(key, value); + this->m_child[0] = new tree_node(key, value); } else if (this->m_key < key) { if (this->m_child[1]) - ret = this->m_child[1]->insert(key, value); + this->m_child[1]->insert_or_update(key, value); else - ret = this->m_child[1] = new tree_node(key, value); + this->m_child[1] = new tree_node(key, value); } else this->m_value = value; + } - return ret; + void update_balance(tree_node * node) + { + this->update_balance(node->m_key); } - int path_update_balance(K const & key) + void update_balance(K const & key) { - if (key < this->m_key) - this->m_stairs[0] = lol::max(this->m_child[0]->path_update_balance(key), this->m_stairs[0]); - else if (this->m_key < key) - this->m_stairs[1] = lol::max(this->m_child[1]->path_update_balance(key), this->m_stairs[1]); + if (key < this->m_key && this->m_child[0]) + this->m_child[0]->update_balance(key); + if (this->m_key < key && this->m_child[1]) + this->m_child[1]->update_balance(key); - return lol::max(this->m_stairs[0], this->m_stairs[1]) + 1; + this->compute_balance(); } - tree_node * path_rebalance(K const & key) + /* Increases stairs according that key is inserted. + * Do not call "increase_path" if key is not already present in the tree. */ + void increase_path(K key) { if (key < this->m_key) { - tree_node * node = this->m_child[0]->path_rebalance(key); - if (node) - { - this->m_child[0] = node; - --this->m_stairs[0]; - } + this->m_child[0]->increase_path(key); + this->compute_balance(); } - else if (this->m_key < key) + if (this->m_key < key) { - tree_node * node = this->m_child[1]->path_rebalance(key); - if (node) - { - this->m_child[1] = node; - --this->m_stairs[1]; - } + this->m_child[1]->increase_path(key); + this->compute_balance(); } + } + + /* Retrieve the parent of the deeper unbalanced node after key insertion. + * Do not call "get_unbalanced_parent" if key is not already present in the tree. */ + tree_node * get_unbalanced_parent(K const & key) + { + tree_node * parent = nullptr; - if (this->get_balance() == 2) + if (key < this->m_key) { - return this->rotate(CCW); + parent = this->m_child[0]->get_unbalanced_parent(key); + + if (parent) + return parent; + else if (abs(this->m_child[0]->get_balance()) == 2) + return this; } - else if (this->get_balance() == -2) + if (this->m_key < key) { - return this->rotate(CW); + parent = this->m_child[1]->get_unbalanced_parent(key); + + if (parent) + return parent; + else if (abs(this->m_child[1]->get_balance()) == 2) + return this; } - ASSERT(lol::abs(this->m_stairs[1] - this->m_stairs[0]) < 3); - return 0; + return nullptr; + } + + void rebalance_children(K const & key) + { + if (key < this->m_key) + { + if (this->m_child[0]->get_balance() == 2) + this->rotateLL(); + if (this->m_child[0]->get_balance() == -2) + this->rotateLR(); + } + else if (this->m_key < key) + { + if (this->m_child[1]->get_balance() == 2) + this->rotateRL(); + if (this->m_child[1]->get_balance() == -2) + this->rotateRR(); + } + else + ASSERT(false) // Do not rebalance the "this" node here + } + + void rotateLL() + { + tree_node * newhead = this->m_child[0]->rotate(CCW); + + this->m_child[0] = newhead; + this->compute_balance(); + } + + void rotateLR() + { + tree_node * newhead = this->m_child[0]->rotate(CW); + + this->m_child[0] = newhead; + this->compute_balance(); + } + + void rotateRL() + { + tree_node * newhead = this->m_child[1]->rotate(CCW); + + this->m_child[1] = newhead; + this->compute_balance(); + } + + void rotateRR() + { + tree_node * newhead = this->m_child[1]->rotate(CW); + + this->m_child[1] = newhead; + this->compute_balance(); } enum Rotation { CW = 0, CCW = 1 }; @@ -175,6 +264,11 @@ protected: return this->m_stairs[1] - this->m_stairs[0]; } + K get_key() + { + return this->m_key; + } + protected: K m_key; diff --git a/src/t/algorithm/avl_tree.cpp b/src/t/algorithm/avl_tree.cpp index 42d10b5e..1a674351 100644 --- a/src/t/algorithm/avl_tree.cpp +++ b/src/t/algorithm/avl_tree.cpp @@ -12,8 +12,6 @@ #include -#include - #include namespace lol