diff --git a/src/lol/algorithm/avl_tree.h b/src/lol/algorithm/avl_tree.h index 3d1d23e2..9fea6d30 100644 --- a/src/lol/algorithm/avl_tree.h +++ b/src/lol/algorithm/avl_tree.h @@ -16,7 +16,6 @@ namespace lol { - template class avl_tree { @@ -38,18 +37,7 @@ public: else { this->m_root->insert_or_update(key, value); - this->m_root->update_balance(key); - tree_node * node = this->m_root->get_unbalanced_parent(key); - - if (node) - { - node->rebalance_children(key); - this->m_root->update_balance(node); - } - else if (this->m_root->get_balance() == -2) - this->m_root = this->m_root->rotate(tree_node::CW); - else if (this->m_root->get_balance() == 2) - this->m_root = this->m_root->rotate(tree_node::CCW); + this->rebalance_if_needed(key); } return true; @@ -63,18 +51,37 @@ public: if (parent) { - parent->delete_child(key); - this->m_root->update_balance(parent); + parent->erase_child(key); + this->rebalance_if_needed(key); } - else if (this->m_root->get_key() == key) + else if (this->m_root->key_equals(key)) { - // TODO + this->m_root = this->m_root->erase_self(); + this->rebalance_if_needed(key); } + + return true; } return false; } + void rebalance_if_needed(K const & key) + { + this->m_root->update_balance(key); + tree_node * node = this->m_root->get_unbalanced_parent(key); + + if (node) + { + node->rebalance_children(key); + this->m_root->update_balance(node); + } + else if (this->m_root->get_balance() == -2) + this->m_root = this->m_root->rotate(tree_node::CW); + else if (this->m_root->get_balance() == 2) + this->m_root = this->m_root->rotate(tree_node::CCW); + } + protected: class tree_node @@ -143,27 +150,13 @@ protected: this->compute_balance(); } - /* Increases stairs according that key is inserted. - * Do not call "increase_path" if key is not already present in the tree. */ - void increase_path(K key) - { - if (key < this->m_key) - { - this->m_child[0]->increase_path(key); - } - if (this->m_key < key) - { - this->m_child[1]->increase_path(key); - } - } - /* Retrieve the parent of the deeper unbalanced node after key insertion. * Do not call "get_unbalanced_parent" if key is not already present in the tree. */ tree_node * get_unbalanced_parent(K const & key) { tree_node * parent = nullptr; - if (key < this->m_key) + if (key < this->m_key && this->m_child[0]) { parent = this->m_child[0]->get_unbalanced_parent(key); @@ -172,7 +165,7 @@ protected: else if (abs(this->m_child[0]->get_balance()) == 2) return this; } - if (this->m_key < key) + if (this->m_key < key && this->m_child[1]) { parent = this->m_child[1]->get_unbalanced_parent(key); @@ -269,61 +262,96 @@ protected: this->m_stairs[1] = this->m_child[1] ? lol::max(this->m_child[1]->m_stairs[0], this->m_child[1]->m_stairs[1]) + 1 : 0; } - bool is_equal(K const & key) + bool key_equals(K const & key) { return !(key < this->m_key) && !(this->m_key < key); } + /* Retrieve the parent of a key. + * Do not call "get_parent" if you’re not sure the key or node is present. */ + tree_node * get_parent(tree_node * node) + { + return this->get_parent(node->m_key); + } + /* Retrieve the parent of an inserted key. * Do not call "get_parent" if key is not already present in the tree. */ tree_node * get_parent(K const & key) { if (key < this->m_key) { - if (this->m_child[0]->is_equal(key)) + if (this->m_child[0]->key_equals(key)) return this; else return this->m_child[0]->get_parent(key); } else if (this->m_key < key) { - if (this->m_child[1]->is_equal(key)) + if (this->m_child[1]->key_equals(key)) return this; else return this->m_child[1]->get_parent(key); } - else - ASSERT(false); // Something went really really bad + + return nullptr; } - void delete_child(tree_node * parent, K const & key) + void erase_child(K const & key) { + tree_node * erase_me = nullptr; + if (key < this->m_key) { - tree_node * child = this->m_child[0]; - - if (this->m_child[0]->get_balance() == -1) - { - //TODO - } - else - { - // TODO - } - - child->m_child[0] = nullptr; - child->m_child[1] = nullptr; - delete child; + erase_me = this->m_child[0]; + this->m_child[0] = erase_me->erase_self(); } else if (this->m_key < key) { - // TODO + erase_me = this->m_child[1]; + this->m_child[1] = erase_me->erase_self(); } else - ASSERT(false) // Do not delete the "this" node here + ASSERT(false) // Do not erase the "this" node here + + delete erase_me; + } + + tree_node * erase_self() + { + tree_node * replacement = nullptr; + + if (this->get_balance() == -1) + { + replacement = this->get_deeper_previous(); + if (replacement) + this->get_parent(replacement)->m_child[1] = replacement->m_child[0]; + } + else // this->get_balance() >= 0 + { + replacement = this->get_deeper_next(); + if (replacement) + this->get_parent(replacement)->m_child[0] = replacement->m_child[1]; + } + + if (replacement) + { + replacement->m_child[0] = this->m_child[0]; + replacement->m_child[1] = this->m_child[1]; + } + + return replacement; + } + + void replace(tree_node * from, tree_node * to) + { + from->m_child[0] = to->m_child[0]; + from->m_child[1] = to->m_child[1]; + + to->m_child[0] = nullptr; + to->m_child[1] = nullptr; } - tree_node * get_previous() + tree_node * get_deeper_previous() { tree_node * previous = this->m_child[0]; @@ -338,7 +366,7 @@ protected: return previous; } - tree_node * get_next() + tree_node * get_deeper_next() { tree_node * next = this->m_child[1]; diff --git a/src/t/algorithm/avl_tree.cpp b/src/t/algorithm/avl_tree.cpp index 098bcceb..84904704 100644 --- a/src/t/algorithm/avl_tree.cpp +++ b/src/t/algorithm/avl_tree.cpp @@ -73,6 +73,26 @@ lolunit_declare_fixture(AvlTreeTest) lolunit_assert_equal(tree.insert(13, 1), true); lolunit_assert_equal(tree.get_root_balance(), 1); } + + lolunit_declare_test(AvlTreeDeletion) + { + test_tree tree; + + lolunit_assert_equal(tree.insert(10, 1), true); + lolunit_assert_equal(tree.get_root_balance(), 0); + + lolunit_assert_equal(tree.insert(20, 1), true); + lolunit_assert_equal(tree.get_root_balance(), 1); + + lolunit_assert_equal(tree.insert(30, 1), true); + lolunit_assert_equal(tree.get_root_balance(), 0); + + lolunit_assert_equal(tree.erase(30), true); + lolunit_assert_equal(tree.insert(30, 1), true); + + lolunit_assert_equal(tree.erase(20), true); + lolunit_assert_equal(tree.insert(20, 1), true); + } }; }