diff --git a/src/solutions/chapter6/section2/exercise6.py b/src/solutions/chapter6/section2/exercise6.py new file mode 100644 index 0000000..d82d8d2 --- /dev/null +++ b/src/solutions/chapter6/section2/exercise6.py @@ -0,0 +1,27 @@ +from book.chapter6.section1 import left +from book.chapter6.section1 import right + + +def iterative_max_heapify(A, i: int) -> None: + """Restores the max-heap property violated by a single node (an iterative version). + + Implements: + Iterative-Max-Heapify + + Args: + A: the array representing a max-heap in which the max-heap property is violated by a single node + i: the index of the node in A that is not larger than either of its children + """ + while True: + l = left(i) + r = right(i) + if l <= A.heap_size and A[l] > A[i]: + largest = l + else: + largest = i + if r <= A.heap_size and A[r] > A[largest]: + largest = r + if largest == i: + return + A[i], A[largest] = A[largest], A[i] + i = largest diff --git a/test/test_solutions/test_chapter6.py b/test/test_solutions/test_chapter6.py index 37fbb59..fb56ce5 100644 --- a/test/test_solutions/test_chapter6.py +++ b/test/test_solutions/test_chapter6.py @@ -6,6 +6,7 @@ from book.chapter6.section1 import build_max_heap from book.data_structures import Array from solutions.chapter6.section2.exercise3 import min_heapify +from solutions.chapter6.section2.exercise6 import iterative_max_heapify from test_case import ClrsTestCase from test_util import create_array from util import range_of @@ -38,3 +39,20 @@ def test_min_heapify(self, data): self.assertEqual(A.heap_size, n) self.assertMinHeap(A) self.assertArrayPermuted(A, elements, end=n) + + @given(st.data()) + def test_iterative_max_heapify(self, data): + elements = data.draw(lists(integers(), min_size=1)) + n = len(elements) + A = create_array(elements) + build_max_heap(A, n) + new_root = data.draw(integers(max_value=A[1] - 1)) + elements.remove(A[1]) + elements.append(new_root) + A[1] = new_root # possibly violate the max-heap property at the root + + iterative_max_heapify(A, 1) + + self.assertEqual(A.heap_size, n) + self.assertMaxHeap(A) + self.assertArrayPermuted(A, elements, end=n)